pandas udf 作为 pyspark 中的 window 函数

pandas udf as a window function in pyspark

目标是使用 pandas user-defined function 作为 pyspark 中的 window 函数。这是一个最小的例子。

df 是一个 pandas DataFrame 和一个 spark table:

import pandas as pd
from pyspark.sql import SparkSession

df = pd.DataFrame(
    {'x': [1, 1, 2, 2, 2, 3, 3],
     'y': [1, 2, 3, 4, 5, 6, 7]})
spark = SparkSession.builder.getOrCreate()
spark.createDataFrame(df).createOrReplaceTempView('df')

这里是 df 作为 spark table

In [10]: spark.sql('SELECT * FROM df').show()
+---+---+
|  x|  y|
+---+---+
|  1|  1|
|  1|  2|
|  2|  3|
|  2|  4|
|  2|  5|
|  3|  6|
|  3|  7|
+---+---+

最小的例子是实现 y 的累加和除以 x。没有任何 pandas 用户定义函数,如下所示:

dx = spark.sql(f"""
    SELECT x, y,
    SUM(y) OVER (PARTITION BY x ORDER BY y) AS ysum
    FROM df
    ORDER BY x""").toPandas()

然后 dx 就是

In [2]: dx
Out[2]:
   x  y  ysum
0  1  1     1
1  1  2     3
2  2  3     3
3  2  4     7
4  2  5    12
5  3  6     6
6  3  7    13

pandas_udf 做同样的尝试是

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@pandas_udf(returnType=DoubleType())
def func(x: pd.Series) -> pd.Series:
    return x.cumsum()
spark.udf.register('func', func)

dx = spark.sql(f"""
    SELECT x, y,
    func(y) OVER (PARTITION BY x ORDER BY y) AS ysum
    FROM df
    ORDER BY x""").toPandas()

哪个returns这个错误

AnalysisException: Expression 'func(y#1L)' not supported within a window function.;
...

更新 根据 wwnde 的回答,解决方案是

def pdf_cumsum(pdf):
    pdf['ysum'] = pdf['y'].cumsum()
    return pdf
dx = sdf.groupby('x').applyInPandas(pdf_cumsum, schema='x long, y long, ysum long').toPandas()

使用地图 Pandas 中的 mapInPandas 函数 API

sch =df.withColumn('ysum',lit(3)).schema
def cumsum_pdf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in iterator:
      yield pdf.assign(ysum=pdf.groupby('x')['y'].cumsum())

df.mapInPandas(cumsum_pdf, schema=sch).show()

结果

+---+---+----+
|  x|  y|ysum|
+---+---+----+
|  1|  1|   1|
|  1|  2|   3|
|  2|  3|   3|
|  2|  4|   7|
|  2|  5|  12|
|  3|  6|   6|
|  3|  7|  13|
+---+---+----+