Rolling correlation and average (last 3) Per Group in PySpark

I have a DataFrame like this:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]
df = spark.createDataFrame(data, ["ID", "colA", "colB"])
df.show()

+---+----+----+
| ID|colA|colB|
+---+----+----+
|ID1|   1|   5|
|ID1|   2|   6|
|ID1|   3|   7|
|ID1|   4|   4|
|ID1|   5|   2|
|ID1|   6|   2|
|ID2|   1|   4|
|ID2|   2|   6|
|ID2|   3|   1|
|ID2|   4|   1|
|ID2|   5|   4|
+---+----+----+

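For each row, I want the correlation of colB with colA and the average of colB over the last 3 elements of its group (the current row plus up to two preceding rows, ordered by colA).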

For example, for ID1:
for the first element (5): average = 5, corr = 0 (a correlation is undefined for a single value, so 0 is used)
for the first 2 elements (5, 6): average = 5.5, corr with colA = 1
for the first 3 elements (5, 6, 7): average = 6, corr with colA = 1
for the elements (6, 7, 4): average = 5.66, corr with colA = -0.65


The expected output looks like this:

    +---+----+----+----------+---------+
    | ID|colA|colB|corr_last3|avg_last3|
    +---+----+----+----------+---------+
    |ID1|   1|   5|         0|        5|
    |ID1|   2|   6|         1|      5.5|
    |ID1|   3|   7|         1|        6|
    |ID1|   4|   4|     -0.65|     5.66|
    |ID1|   5|   2|     -0.99|     4.33|
    |ID1|   6|   2|     -0.86|     2.66|
    |ID2|   1|   4|         0|        4|
    |ID2|   2|   6|         1|        5|
    |ID2|   3|   1|     -0.59|     3.66|
    |ID2|   4|   1|     -0.86|     2.66|
    |ID2|   5|   4|      0.86|        2|
    +---+----+----+----------+---------+
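The expected numbers can be cross-checked without Spark. Below is a minimal plain-Python sketch, assuming a trailing window of the current row plus up to two preceding rows (ordered by colA within each ID) and 0 for single-row windows, as the question specifies. The helpers pearson and rolling_last3 are hypothetical names for illustration only; values are rounded to 3 decimals, so 5.667 appears where the table shows 5.66.

```python
import math
from collections import defaultdict

# Same rows as the question's DataFrame: (ID, colA, colB)
data = [("ID1", 1, 5), ("ID1", 2, 6), ("ID1", 3, 7),
        ("ID1", 4, 4), ("ID1", 5, 2), ("ID1", 6, 2),
        ("ID2", 1, 4), ("ID2", 2, 6), ("ID2", 3, 1), ("ID2", 4, 1), ("ID2", 5, 4)]


def pearson(xs, ys):
    """Hypothetical helper: Pearson correlation of two equal-length sequences.

    The 1/n factors of covariance and variance cancel, so sums are enough.
    Returns NaN when either sequence is constant (correlation undefined).
    """
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var_x = sum((x - mx) ** 2 for x in xs)
    var_y = sum((y - my) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        return float("nan")
    return cov / math.sqrt(var_x * var_y)


def rolling_last3(rows):
    """Hypothetical helper: (ID, colA, colB, corr_last3, avg_last3) for every row."""
    groups = defaultdict(list)
    for id_, a, b in rows:
        groups[id_].append((a, b))
    out = []
    for id_ in sorted(groups):
        pts = sorted(groups[id_])  # order by colA within the group
        for i, (a, b) in enumerate(pts):
            window = pts[max(0, i - 2): i + 1]  # current row and up to 2 before it
            xs = [p[0] for p in window]
            ys = [p[1] for p in window]
            # a single point has no correlation; use 0.0 as the question does
            corr = pearson(xs, ys) if len(window) > 1 else 0.0
            out.append((id_, a, b, round(corr, 3), round(sum(ys) / len(ys), 3)))
    return out


for row in rolling_last3(data):
    print(row)
```

Each printed tuple corresponds to one row of the expected table, with corr_last3 and avg_last3 rounded to three decimals.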

You can do this with the built-in functions avg and corr applied over a window. Here is a Scala solution:

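import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._ // $"col" syntax; assumes a SparkSession named spark (as in spark-shell)

// Rows of each ID ordered by colA, and a frame of the current row plus the 2 rows before it
val byId = Window.partitionBy($"ID").orderBy($"colA")
val last3 = byId.rowsBetween(-2L, Window.currentRow)

df
  .withColumn("indices", row_number().over(byId)) // row position within the ID (equals colA here)
  // corr is undefined for a single row, so the first row of each ID gets 0.0
  .withColumn("corr_last3", when($"indices" > 1, corr($"indices", $"colB").over(last3)).otherwise(0.0))
  .withColumn("avg_last3", avg($"colB").over(last3))
  .drop($"indices")
  .orderBy($"ID", $"colA")
  .show()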

which gives:

+---+----+----+-------------------+------------------+
| ID|colA|colB|         corr_last3|         avg_last3|
+---+----+----+-------------------+------------------+
|ID1|   1|   5|                0.0|               5.0|
|ID1|   2|   6|                1.0|               5.5|
|ID1|   3|   7|                1.0|               6.0|
|ID1|   4|   4|-0.6546536707079772| 5.666666666666667|
|ID1|   5|   2|-0.9933992677987828| 4.333333333333333|
|ID1|   6|   2|-0.8660254037844386|2.6666666666666665|
|ID2|   1|   4|                0.0|               4.0|
|ID2|   2|   6|                1.0|               5.0|
|ID2|   3|   1|-0.5960395606792697|3.6666666666666665|
|ID2|   4|   1|-0.8660254037844387|2.6666666666666665|
|ID2|   5|   4| 0.8660254037844387|               2.0|
+---+----+----+-------------------+------------------+

The PySpark version of the answer:

from pyspark.sql import Window
from pyspark.sql import functions as F

# Rows of each ID ordered by colA, and a frame of the current row plus the 2 rows before it.
# rowsBetween counts physical rows; rangeBetween would select by colA value instead,
# which only matches here because colA runs 1, 2, 3, ... within each ID.
by_id = Window.partitionBy("ID").orderBy("colA")
last3 = by_id.rowsBetween(-2, Window.currentRow)

df = (
    df.withColumn("indices", F.row_number().over(by_id))  # row position within the ID
    # corr is undefined for a single row, so the first row of each ID gets 0.0
    .withColumn(
        "corr_last3",
        F.when(F.col("indices") > 1, F.corr("indices", "colB").over(last3)).otherwise(0.0),
    )
    .withColumn("avg_last3", F.avg("colB").over(last3))
    .drop("indices")
    # round each column from its own values for display
    .withColumn("corr_last3", F.round("corr_last3", 3))
    .withColumn("avg_last3", F.round("avg_last3", 3))
    .orderBy("ID", "colA")
)
df.show()

+---+----+----+----------+---------+
| ID|colA|colB|corr_last3|avg_last3|
+---+----+----+----------+---------+
|ID1|   1|   5|       0.0|      5.0|
|ID1|   2|   6|       1.0|      5.5|
|ID1|   3|   7|       1.0|      6.0|
|ID1|   4|   4|    -0.655|    5.667|
|ID1|   5|   2|    -0.993|    4.333|
|ID1|   6|   2|    -0.866|    2.667|
|ID2|   1|   4|       0.0|      4.0|
|ID2|   2|   6|       1.0|      5.0|
|ID2|   3|   1|    -0.596|    3.667|
|ID2|   4|   1|    -0.866|    2.667|
|ID2|   5|   4|     0.866|      2.0|
+---+----+----+----------+---------+
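Whether the "last 3" frame is built with rowsBetween or rangeBetween matters once colA has gaps: a ROWS frame counts physical rows, while a RANGE frame keeps the rows whose ordering value lies within the offset of the current row's value. The plain-Python sketch below only mimics those two rules (it does not call Spark); rows_frame and range_frame are hypothetical names, the colA sequence with a gap is made up for illustration, and the values are assumed to be sorted ascending.

```python
# Hypothetical ordering values with a gap (colA jumps from 2 to 5); not from the question.
col_a = [1, 2, 5, 6, 7]


def rows_frame(i, preceding=2):
    """Indices in a ROWS frame: the current row and up to `preceding` rows before it."""
    return list(range(max(0, i - preceding), i + 1))


def range_frame(values, i, preceding=2):
    """Indices in a RANGE frame: rows whose value lies in [values[i] - preceding, values[i]]."""
    v = values[i]
    return [j for j, w in enumerate(values) if v - preceding <= w <= v]


for i, v in enumerate(col_a):
    rows_vals = [col_a[j] for j in rows_frame(i)]
    range_vals = [col_a[j] for j in range_frame(col_a, i)]
    print(f"colA={v}: rows frame {rows_vals}, range frame {range_vals}")
```

At colA = 5 the ROWS frame still holds three rows (1, 2, 5), while the RANGE frame holds only the row with colA = 5, so a "last 3 elements" requirement corresponds to rowsBetween(-2, currentRow). For the question's colA values (1 to 6 with no gaps) the two rules select the same rows.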