摘要
Apache Arrow是一种内存列式数据格式,Spark用它在JVM和Python进程之间高效地传输数据。
使用arrow
使用前须首先启用基于Arrow的数据传输,即将spark.sql.execution.arrow.pyspark.enabled设置为true。
import numpy as np
import pandas as pd

# NOTE: `spark` is not defined in this snippet; it is assumed to be an
# existing SparkSession (e.g. the one provided by the pyspark shell).

# Use Arrow for data transfer.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# Build a pandas DataFrame of random values (100 rows x 3 columns).
pdf = pd.DataFrame(np.random.rand(100, 3))

# Create a Spark DataFrame from the pandas DataFrame.
df = spark.createDataFrame(pdf)

# Convert the Spark DataFrame back to a pandas DataFrame.
result_pdf = df.select("*").toPandas()
pandas udf(向量化udf)
pandas udf基于Arrow进行数据传输,基于pandas进行数据计算,支持向量化操作。使用pandas_udf装饰器即可将用户自定义函数封装为一个pandas udf。自Spark 3.0起,推荐使用Python类型提示(type hints)来定义pandas udf。
通常情况下,输入和输出类型应为pandas.Series;唯一的例外是输入或输出列为StructType时,此时对应的类型应为pandas.DataFrame。
import pandas as pd

from pyspark.sql.functions import pandas_udf


@pandas_udf("col1 string, col2 long")
def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
    """Add a ``col2`` column to the struct input and return it.

    ``s1`` and ``s2`` are annotated as pandas Series, while ``s3`` (the
    struct column) and the struct result are annotated as pandas DataFrame.
    ``col2`` is ``s1`` plus the string length of each element of ``s2``.
    """
    s3["col2"] = s1 + s2.str.len()
    return s3


# NOTE: `spark` is not defined in this snippet; it is assumed to be an
# existing SparkSession.

# Create a Spark DataFrame that has three columns including a struct column.
df = spark.createDataFrame(
    [[1, "a string", ("a nested string",)]],
    "long_col long, string_col string, struct_col struct<col1:string>")

df.printSchema()
# root
# |-- long_col: long (nullable = true)
# |-- string_col: string (nullable = true)
# |-- struct_col: struct (nullable = true)
# |    |-- col1: string (nullable = true)

df.select(func("long_col", "string_col", "struct_col")).printSchema()
# |-- func(long_col, string_col, struct_col): struct (nullable = true)
# |    |-- col1: string (nullable = true)
# |    |-- col2: long (nullable = true)
Series to Series
在内部,Spark将列拆分为多个batch,对每个batch分别调用pandas udf,计算完成后再将各batch的结果拼接起来。
以下pandas udf实现两列相乘。
import pandas as pd

from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType


def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    """Return the element-wise product of two pandas Series."""
    return a * b


# Wrap the plain function as a pandas UDF whose result column is a long.
multiply = pandas_udf(multiply_func, returnType=LongType())

# The undecorated function can be called directly on local pandas data.
x = pd.Series([1, 2, 3])
print(multiply_func(x, x))
# 0    1
# 1    4
# 2    9
# dtype: int64

# NOTE: `spark` is not defined in this snippet; it is assumed to be an
# existing SparkSession.
df = spark.createDataFrame(pd.DataFrame(x, columns=["x"]))

# Apply the pandas UDF to the Spark column "x".
df.select(multiply(col("x"), col("x"))).show()
# +-------------------+
# |multiply_func(x, x)|
# +-------------------+
# |                  1|
# |                  4|
# |                  9|
# +-------------------+
from typing import Iterator

import pandas as pd

from pyspark.sql.functions import pandas_udf


# Iterator of Series -> Iterator of Series pandas UDF.
# NOTE: `very_expensive_initialization`, `calculate_with_state` and `df`
# are not defined in this snippet; they are placeholders that the reader
# must supply.
@pandas_udf("long")
def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Yield one result Series per input Series, sharing one state object."""
    # Do some expensive initialization with a state
    state = very_expensive_initialization()
    for x in iterator:
        # Use that state for whole iterator.
        yield calculate_with_state(x, state)


df.select(calculate("value")).show()
from typing import Iterator

import pandas as pd

from pyspark.sql.functions import pandas_udf

# NOTE: `spark` is not defined in this snippet; it is assumed to be an
# existing SparkSession.
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)


# Declare the function and create the UDF
@pandas_udf("long")
def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Yield each incoming Series with 1 added to every element."""
    for x in iterator:
        yield x + 1


df.select(plus_one("x")).show()
# +-----------+
# |plus_one(x)|
# +-----------+
# |          2|
# |          3|
# |          4|
# +-----------+
from typing import Iterator, Tuple

import pandas as pd

from pyspark.sql.functions import pandas_udf

# NOTE: `spark` is not defined in this snippet; it is assumed to be an
# existing SparkSession.
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)


# Declare the function and create the UDF
@pandas_udf("long")
def multiply_two_cols(
        iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    """Yield the element-wise product of each (a, b) pair of Series."""
    for a, b in iterator:
        yield a * b


df.select(multiply_two_cols("x", "x")).show()
# +-----------------------+
# |multiply_two_cols(x, x)|
# +-----------------------+
# |                      1|
# |                      4|
# |                      9|
# +-----------------------+
import pandas as pd

from pyspark.sql.functions import pandas_udf
from pyspark.sql import Window

# NOTE: `spark` is not defined in this snippet; it is assumed to be an
# existing SparkSession.
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))


# Declare the function and create the UDF
@pandas_udf("double")
def mean_udf(v: pd.Series) -> float:
    """Return the mean of the values in ``v`` as a single scalar."""
    return v.mean()


# Aggregate over the whole column.
df.select(mean_udf(df["v"])).show()
# +-----------+
# |mean_udf(v)|
# +-----------+
# |        4.2|
# +-----------+

# Aggregate per group of "id".
df.groupby("id").agg(mean_udf(df["v"])).show()
# +---+-----------+
# | id|mean_udf(v)|
# +---+-----------+
# |  1|        1.5|
# |  2|        6.0|
# +---+-----------+

# Window partitioned by "id" with an unbounded frame in both directions.
# Parentheses let the method chain span several lines.
w = (
    Window
    .partitionBy("id")
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)
df.withColumn("mean_v", mean_udf(df["v"]).over(w)).show()
# +---+----+------+
# | id|   v|mean_v|
# +---+----+------+
# |  1| 1.0|   1.5|
# |  1| 2.0|   1.5|
# |  2| 3.0|   6.0|
# |  2| 5.0|   6.0|
# |  2|10.0|   6.0|
# +---+----+------+
在pyspark中使用pandas udf/apache Arrow
原文:https://www.cnblogs.com/zcsh/p/14840836.html