
Pyspark groupBy multiple columns and aggregate using multiple udf functions

I want to group by multiple columns and then aggregate several other columns with a user-defined function (udf) that computes the mode of each column. This sample code demonstrates my problem:

import pandas as pd
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType, IntegerType

df = pd.DataFrame(columns=['A', 'B', 'C', 'D'])
df["A"] = ["Mon", "Mon", "Mon", "Fri", "Fri", "Fri", "Fri"]
df["B"] = ["Feb", "Feb", "Feb", "May", "May", "May", "May"]
df["C"] = ["x", "y", "y", "m", "n", "r", "r"]
df["D"] = [3, 3, 5, 1, 1, 1, 9]
df_sdf = spark.createDataFrame(df)
df_sdf.show()

+---+---+---+---+
|  A|  B|  C|  D|
+---+---+---+---+
|Mon|Feb|  x|  3|
|Mon|Feb|  y|  3|
|Mon|Feb|  y|  5|
|Fri|May|  m|  1|
|Fri|May|  n|  1|
|Fri|May|  r|  1|
|Fri|May|  r|  9|
+---+---+---+---+

# Custom mode function to get mode value for string list and integer list
def custom_mode(lst): return(max(lst, key=lst.count))
custom_mode_str = udf(custom_mode, StringType())
custom_mode_int = udf(custom_mode, IntegerType())
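# Quick check: called on plain Python lists the helper returns the most frequent
# value, e.g. custom_mode(["x", "y", "y"]) -> "y" and custom_mode([3, 3, 5]) -> 3.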

grp_columns = ["A", "B"]
df_sdf.groupBy(grp_columns).agg(custom_mode_str(col("C")).alias("C"), custom_mode_int(col("D")).alias("D")).distinct().show()

However, the last line of the code above raises the following error:

AnalysisException: expression '`C`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;;

The expected output of this code is:

+---+---+---+---+
|  A|  B|  C|  D|
+---+---+---+---+
|Mon|Feb|  y|  3|
|Fri|May|  r|  1|
+---+---+---+---+

I have searched a lot but could not find anything like this problem for pyspark. Thank you for your time.

Your UDF expects a list, but you are passing it a column of the Spark dataframe. A plain udf is applied row by row and is not an aggregate function, which is why Spark rejects it inside agg(). Collect each group's values into a list first and pass that list to your function to get the desired result.

from pyspark.sql import functions as func

# collect_list() gathers each group's values into an array,
# which arrives inside the UDF as a Python list.
df_sdf.groupBy(['A', 'B']). \
    agg(custom_mode_str(func.collect_list('C')).alias('C'),
        custom_mode_int(func.collect_list('D')).alias('D')
        ). \
    show()

# +---+---+---+---+
# |  A|  B|  C|  D|
# +---+---+---+---+
# |Mon|Feb|  y|  3|
# |Fri|May|  r|  1|
# +---+---+---+---+

collect_list() is the key here, since it produces the list that your UDF operates on. See the collected output below.

df_sdf.groupBy(['A', 'B']). \
    agg(func.collect_list('C').alias('C_collected'),
        func.collect_list('D').alias('D_collected')
        ). \
    show()

# +---+---+------------+------------+
# |  A|  B| C_collected| D_collected|
# +---+---+------------+------------+
# |Mon|Feb|   [x, y, y]|   [3, 3, 5]|
# |Fri|May|[m, n, r, r]|[1, 1, 1, 9]|
# +---+---+------------+------------+
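
One caveat with this approach: max(lst, key=lst.count) breaks ties by whichever tied value happens to come first in the list, and collect_list() does not guarantee ordering, so tied modes may resolve differently between runs. If you are on Spark 3.4 or later, the built-in mode() aggregate is another option that avoids both the Python UDF and the intermediate lists; a minimal sketch, assuming Spark 3.4+:

from pyspark.sql import functions as func

# mode() is an aggregate function (Spark 3.4+), so it can be used
# inside agg() directly, without collect_list() or a UDF.
df_sdf.groupBy(['A', 'B']). \
    agg(func.mode('C').alias('C'),
        func.mode('D').alias('D')
        ). \
    show()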