在 PySpark 数据框中的组中的列上应用函数

Apply a function over a column in a group in PySpark dataframe

我有一个像这样的 PySpark 数据框,

+----------+--------+---------+
|id_       | p      |   a     |
+----------+--------+---------+
|  1       | 4      |   12    |
|  1       | 3      |   14    |
|  1       | -7     |   16    |
|  1       | 5      |   11    |
|  1       | -20    |   90    |
|  1       | 5      |   120   |
|  2       |  11    |   267   |
|  2       | -98    |   124   |
|  2       | -87    |   120   |
|  2       | -1     |   44    |
|  2       |  5     |   1     |
|  2       |  7     |   23    |
-------------------------------

我也有这样的python功能,

def fun(x):
    total = 0
    result = np.empty_like(x)
    for i, y in enumerate(x):
        total += (y)
        if total < 0:
            total = 0
        result[i] = total

    return result

我想在列 id_ 上对 PySpark 数据框进行分组,并在列 p 上应用函数 fun

我想要

spark_df.groupBy('id_')['p'].apply(fun)

我目前正在 pyarrow 的帮助下使用 pandas udf 执行此操作,这对我的应用程序来说效率不高。

我要找的结果是,

[4, 7, 0, 5, 0, 5, 11, -98, -87, -1, 5, 7]

这是我正在寻找的结果数据框,

+----------+--------+---------+
|id_       | p      |   a     |
+----------+--------+---------+
|  1       | 4      |   12    |
|  1       | 7      |   14    |
|  1       | 0      |   16    |
|  1       | 5      |   11    |
|  1       | 0      |   90    |
|  1       | 5      |   120   |
|  2       |  11    |   267   |
|  2       | 0      |   124   |
|  2       | 0      |   120   |
|  2       | 0      |   44    |
|  2       |  5     |   1     |
|  2       |  12    |   23    |
-------------------------------

有没有一种直接的方法可以用 pyspark API 本身来做到这一点。?

我可以在 id_ 上分组时使用 collect_listp 聚合并列到一个列表中,然后使用 udf 并使用 explode 来获取结果数据框中我需要的 p 列。

但是如何保留我的数据框中的其他列。?

是的,您可以将上述 python 函数转换为 Pyspark UDF。 由于您正在 returning 整数数组,因此将 return 类型指定为 ArrayType(IntegerType()).

很重要

下面是代码,

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, IntegerType, collect_list

@udf(returnType=ArrayType(IntegerType()))
def fun(x):
    total = 0
    result = np.empty_like(x)
    for i, y in enumerate(x):
        total += (y)
        if total < 0:
            total = 0
        result[i] = total
    return result.tolist()    # Convert NumPy Array to Python List

由于您 udf 的输入必须是列表,让我们根据 'id' 对数据进行分组并将行转换为数组。

df = df.groupBy('id_').agg(collect_list('p'))
df = df.toDF('id_', 'p_')    # Assign a new alias name 'p_'
df.show(truncate=False)

输入数据:

+---+------------------------+
|id_|collect_list(p)         |
+---+------------------------+
|1  |[4, 3, -7, 5, -20, 5]   |
|2  |[11, -98, -87, -1, 5, 7]|
+---+------------------------+

接下来,我们将 udf 应用于此数据,

df.select('id_', fun(df.p_)).show(truncate=False)

输出:

+---+--------------------+
|id_|fun(p_)             |
+---+--------------------+
|1  |[4, 7, 0, 5, 0, 5]  |
|2  |[11, 0, 0, 0, 5, 12]|
+---+--------------------+

我通过以下步骤成功实现了我需要的结果,

我的 DataFrame 看起来像这样,

+---+---+---+
|id_|  p|  a|
+---+---+---+
|  1|  4| 12|
|  1|  3| 14|
|  1| -7| 16|
|  1|  5| 11|
|  1|-20| 90|
|  1|  5|120|
|  2| 11|267|
|  2|-98|124|
|  2|-87|120|
|  2| -1| 44|
|  2|  5|  1|
|  2|  7| 23|
+---+---+---+

我将对 id_ 上的数据框进行分组并收集我想使用 collect_list 将函数应用于列表的列并像这样应用函数,

agg_df = df.groupBy('id_').agg(F.collect_list('p').alias('collected_p'))
agg_df = agg_df.withColumn('new', fun('collected_p'))

我现在想以某种方式将 agg_df 合并到我的原始数据框。为此,我将首先使用 explode 获取行中 new 列中的值。

agg_df = agg_df.withColumn('exploded', F.explode('new'))

为了合并,我将使用 monotonically_increasing_id 为原始数据帧和 agg_df 生成 id。从那以后,我将为每个数据帧制作 idx,因为两个数据帧的 monotonically_increasing_id 不相同。

agg_df = agg_df.withColumn('id_mono', F.monotonically_increasing_id())
df = df.withColumn('id_mono', F.monotonically_increasing_id())

w = Window().partitionBy(F.lit(0)).orderBy('id_mono')

df = df.withColumn('idx', F.row_number().over(w))
agg_df = agg_df.withColumn('idx', F.row_number().over(w))

df = df.join(agg_df.select('idx', 'exploded'), ['idx']).drop('id_mono', 'idx')


+---+---+---+--------+
|id_|  p|  a|exploded|
+---+---+---+--------+
|  1|  4| 12|       4|
|  1|  3| 14|       7|
|  1| -7| 16|       0|
|  1|  5| 11|       5|
|  1|-20| 90|       0|
|  1|  5|120|       5|
|  2| 11|267|      11|
|  2|-98|124|       0|
|  2|-87|120|       0|
|  2| -1| 44|       0|
|  2|  5|  1|       5|
|  2|  7| 23|      12|
+---+---+---+--------+

我不确定这是一个直接的方法。如果有人可以为此提出任何优化建议,那就太好了。