如何在 PySpark 中计算具有不同 window 大小的滚动总和

How to calculate rolling sum with varying window sizes in PySpark

我有一个 spark 数据框,其中包含一段时间内某些商店中某些产品的销售预测数据。如何计算 window 大小的下 N 个值的预测的滚动总和?

输入数据

+-----------+---------+------------+------------+---+
| ProductId | StoreId |    Date    | Prediction | N |
+-----------+---------+------------+------------+---+
|         1 |     100 | 2019-07-01 | 0.92       | 2 |
|         1 |     100 | 2019-07-02 | 0.62       | 2 |
|         1 |     100 | 2019-07-03 | 0.89       | 2 |
|         1 |     100 | 2019-07-04 | 0.57       | 2 |
|         2 |     200 | 2019-07-01 | 1.39       | 3 |
|         2 |     200 | 2019-07-02 | 1.22       | 3 |
|         2 |     200 | 2019-07-03 | 1.33       | 3 |
|         2 |     200 | 2019-07-04 | 1.61       | 3 |
+-----------+---------+------------+------------+---+

预期输出数据

+-----------+---------+------------+------------+---+------------------------+
| ProductId | StoreId |    Date    | Prediction | N |       RollingSum       |
+-----------+---------+------------+------------+---+------------------------+
|         1 |     100 | 2019-07-01 | 0.92       | 2 | sum(0.92, 0.62)        |
|         1 |     100 | 2019-07-02 | 0.62       | 2 | sum(0.62, 0.89)        |
|         1 |     100 | 2019-07-03 | 0.89       | 2 | sum(0.89, 0.57)        |
|         1 |     100 | 2019-07-04 | 0.57       | 2 | sum(0.57)              |
|         2 |     200 | 2019-07-01 | 1.39       | 3 | sum(1.39, 1.22, 1.33)  |
|         2 |     200 | 2019-07-02 | 1.22       | 3 | sum(1.22, 1.33, 1.61 ) |
|         2 |     200 | 2019-07-03 | 1.33       | 3 | sum(1.33, 1.61)        |
|         2 |     200 | 2019-07-04 | 1.61       | 3 | sum(1.61)              |
+-----------+---------+------------+------------+---+------------------------+

Python 中有很多关于这个问题的问答,但我在 PySpark 中找不到。

类似问题1
有一个类似的问题 here 但在此一帧大小固定为 3。在提供的答案中使用了 rangeBetween 函数,它只适用于固定大小的帧,所以我不能将它用于不同的大小。

类似问题2
还有一个类似的问题here。在这一篇中,建议为所有可能的大小编写案例,但它不适用于我的案例,因为我不知道我需要计算多少个不同的帧大小。

解决方案尝试 1
我尝试使用 pandas udf:

来解决问题
rolling_sum_predictions = predictions.groupBy('ProductId', 'StoreId').apply(calculate_rolling_sums)

calculate_rolling_sums 是一个 pandas udf,我在其中解决了 python 中的问题。该解决方案适用于少量测试数据。然而,当数据变大时(在我的例子中,输入 df 有大约 1B 行),计算时间很长。

解决方案尝试 2
我使用了上述 类似问题 1 的解决方法。我已经计算出最大可能的 N,使用它创建列表,然后通过对列表进行切片来计算预测总和。

predictions = predictions.withColumn('DayIndex', F.rank().over(Window.partitionBy('ProductId', 'StoreId').orderBy('Date')))

# find the biggest period
biggest_period = predictions.agg({"N": "max"}).collect()[0][0]

# calculate rolling predictions starting from the DayIndex
w = (Window.partitionBy(F.col("ProductId"), F.col("StoreId")).orderBy(F.col('DayIndex')).rangeBetween(0, biggest_period - 1))
rolling_prediction_lists = predictions.withColumn("next_preds", F.collect_list("Prediction").over(w))

# calculate rolling forecast sums
pred_sum_udf = udf(lambda preds, period: float(np.sum(preds[:period])), FloatType())
rolling_pred_sums = rolling_prediction_lists \
    .withColumn("RollingSum", pred_sum_udf("next_preds", "N"))

这个解决方案也适用于测试数据。我还没有机会用原始数据对其进行测试,但无论它是否有效,我都不喜欢这个解决方案。有没有更聪明的方法来解决这个问题?

如果您使用的是 spark 2.4+,则可以使用新的 higher-order array functions sliceaggregate 来有效地实现您的要求,而无需任何 UDF:

summed_predictions = predictions\
   .withColumn("summed", F.collect_list("Prediction").over(Window.partitionBy("ProductId", "StoreId").orderBy("Date").rowsBetween(Window.currentRow, Window.unboundedFollowing))\
   .withColumn("summed", F.expr("aggregate(slice(summed,1,N), cast(0 as double), (acc,d) -> acc + d)"))

summed_predictions.show()
+---------+-------+-------------------+----------+---+------------------+
|ProductId|StoreId|               Date|Prediction|  N|            summed|
+---------+-------+-------------------+----------+---+------------------+
|        1|    100|2019-07-01 00:00:00|      0.92|  2|              1.54|
|        1|    100|2019-07-02 00:00:00|      0.62|  2|              1.51|
|        1|    100|2019-07-03 00:00:00|      0.89|  2|              1.46|
|        1|    100|2019-07-04 00:00:00|      0.57|  2|              0.57|
|        2|    200|2019-07-01 00:00:00|      1.39|  3|              3.94|
|        2|    200|2019-07-02 00:00:00|      1.22|  3|              4.16|
|        2|    200|2019-07-03 00:00:00|      1.33|  3|2.9400000000000004|
|        2|    200|2019-07-04 00:00:00|      1.61|  3|              1.61|
+---------+-------+-------------------+----------+---+------------------+

它可能不是最好的,但您可以获得不同的 "N" 列值并像下面这样循环。

val arr = df.select("N").distinct.collect

for(n <- arr) df.filter(col("N") ===  n.get(0))
.withColumn("RollingSum",sum(col("Prediction"))
.over(Window.partitionBy("N").orderBy("N").rowsBetween(Window.currentRow, n.get(0).toString.toLong-1))).show

这会给你喜欢的:

+---------+-------+----------+----------+---+------------------+
|ProductId|StoreId|      Date|Prediction|  N|        RollingSum|
+---------+-------+----------+----------+---+------------------+
|        2|    200|2019-07-01|      1.39|  3|              3.94|
|        2|    200|2019-07-02|      1.22|  3|              4.16|
|        2|    200|2019-07-03|      1.33|  3|2.9400000000000004|
|        2|    200|2019-07-04|      1.61|  3|              1.61|
+---------+-------+----------+----------+---+------------------+

+---------+-------+----------+----------+---+----------+
|ProductId|StoreId|      Date|Prediction|  N|RollingSum|
+---------+-------+----------+----------+---+----------+
|        1|    100|2019-07-01|      0.92|  2|      1.54|
|        1|    100|2019-07-02|      0.62|  2|      1.51|
|        1|    100|2019-07-03|      0.89|  2|      1.46|
|        1|    100|2019-07-04|      0.57|  2|      0.57|
+---------+-------+----------+----------+---+----------+

然后你可以对循环内的所有数据帧进行并集。