PySpark:将不同的 window 大小应用于 pyspark 中的数据框

PySpark: applying varying window sizes to a dataframe in pyspark

我有一个如下所示的 spark 数据框。

date ID window_size qty
01/01/2020 1 2 1
02/01/2020 1 2 2
03/01/2020 1 2 3
04/01/2020 1 2 4
01/01/2020 2 3 1
02/01/2020 2 3 2
03/01/2020 2 3 3
04/01/2020 2 3 4

我正在尝试对数据框中的每个 ID 应用大小为 window_size 的滚动 window 并获得滚动总和。基本上我正在计算滚动总和(pandas 中的 pd.groupby.rolling(window=n).sum()),其中 window 大小 (n) 可以按组更改。

预期输出

date ID window_size qty rolling_sum
01/01/2020 1 2 1 null
02/01/2020 1 2 2 3
03/01/2020 1 2 3 5
04/01/2020 1 2 4 7
01/01/2020 2 3 1 null
02/01/2020 2 3 2 null
03/01/2020 2 3 3 6
04/01/2020 2 3 4 9

我正在努力寻找一种在大型数据帧(+- 350M 行)上运行且速度足够快的解决方案。

我试过的

我尝试了下面的解决方案 :

想法是先使用 sf.collect_list,然后正确地分割 ArrayType 列。

import pyspark.sql.types as st
import pyspark.sql.function as sf

window = Window.partitionBy('id').orderBy(params['date'])

output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
                                 .otherwise(sf.slice('qty_list', sf.col('count'), sf.col('window_size'))))
).show()

但是这会产生以下错误:

TypeError: Column is not iterable

我也试过使用 sf.expr 如下

window = Window.partitionBy('id').orderBy(params['date'])

output = (
    sdf
    .withColumn("qty_list", sf.collect_list('qty').over(window))
    .withColumn("count", sf.count('qty').over(window))
    .withColumn("rolling_sum", sf.when(sf.col('count') < sf.col('window_size'), None)
                                 .otherwise(sf.expr("slice('window_size', 'count', 'window_size')")))
).show()

产生:

data type mismatch: argument 1 requires array type, however, ''qty_list'' is of string type.; line 1 pos 0;

我尝试将 qty_list 列手动转换为 ArrayType(IntegerType()),结果相同。

我尝试使用 UDF,但在 1.5 小时左右后失败并出现多个内存不足错误。

问题

  1. 阅读火花 documentation 向我建议我应该能够将列传递给 sf.slice(),我做错了什么吗? TypeError 来自哪里?

  2. 有没有更好的方法不用sf.collect_list()and/orsf.slice()就可以达到我想要的效果?

  3. 如果所有其他方法都失败了,使用 udf 执行此操作的最佳方法是什么?我尝试了同一个 udf 的不同版本,并试图确保 udf 是 spark 必须执行的最后一个操作,但都失败了。

关于您遇到的错误:

  1. 第一个表示您不能使用 DataFrame API 函数将列传递给 slice(除非您有 Spark 3.1+)。但是当您尝试在 SQL 表达式中使用它时,您已经掌握了它。
  2. 发生第二个错误是因为您传递了 expr 中引用的列名。它应该是 slice(qty_list, count, window_size),否则 Spark 会将它们视为字符串,因此会出现错误消息。

也就是说,您差不多明白了,您需要更改切片表达式以获得正确的数组大小,然后使用 aggregate 函数对结果数组的值求和。试试这个:

from pyspark.sql import Window
import pyspark.sql.functions as F

w = Window.partitionBy('id').orderBy('date')

output = df.withColumn("qty_list", F.collect_list('qty').over(w)) \
    .withColumn("rn", F.row_number().over(w)) \
    .withColumn(
        "qty_list",
        F.when(
            F.col('rn') < F.col('window_size'),
            None
        ).otherwise(F.expr("slice(qty_list, rn-window_size+1, window_size)"))
    ).withColumn(
        "rolling_sum",
        F.expr("aggregate(qty_list, 0D, (acc, x) -> acc + x)").cast("int")
    ).drop("qty_list", "rn")

output.show()
#+----------+---+-----------+---+-----------+
#|      date| ID|window_size|qty|rolling_sum|
#+----------+---+-----------+---+-----------+
#|01/01/2020|  1|          2|  1|       null|
#|02/01/2020|  1|          2|  2|          3|
#|03/01/2020|  1|          2|  3|          5|
#|04/01/2020|  1|          2|  4|          7|
#|01/01/2020|  2|          3|  1|       null|
#|02/01/2020|  2|          3|  2|       null|
#|03/01/2020|  2|          3|  3|          6|
#|04/01/2020|  2|          3|  4|          9|
#+----------+---+-----------+---+-----------+