Time based window function in Pyspark

My goal is to compute an additional column, keeping the same number of rows as the original DataFrame, that shows each user's average balance over the last 30 days.

I guess this can be done with a window function, partitioning by user and somehow limiting the frame to the rows between the current date and 30 days before it, but I have no idea how to implement that in PySpark.

I have the following Spark DataFrame:

userId  date        balance
A       09/06/2020  100
A       03/07/2020  200
A       05/08/2020  600
A       30/08/2020  1000
A       15/09/2020  500
B       03/01/2020  100
B       05/04/2020  200
B       29/04/2020  600
B       01/05/2020  1600

The desired output DataFrame is:

userId  date        balance  mean_last_30days_balance
A       09/06/2020  100      100
A       03/07/2020  200      150
A       05/08/2020  600      600
A       30/08/2020  1000     800
A       15/09/2020  500      750
B       03/01/2020  100      100
B       05/04/2020  200      200
B       29/04/2020  600      400
B       01/05/2020  1600     800
Code to reproduce the example (assumes an active SparkSession bound to spark, as in a pyspark shell or notebook):

from datetime import datetime

from pyspark.sql import functions as F
from pyspark.sql import types as T

data = [("A",datetime.strptime("09/06/2020",'%d/%m/%Y'),100),
        ("A",datetime.strptime("03/07/2020",'%d/%m/%Y'),200),
        ("A",datetime.strptime("05/08/2020",'%d/%m/%Y'),600),
        ("A",datetime.strptime("30/08/2020",'%d/%m/%Y'),1000),
        ("A",datetime.strptime("15/09/2020",'%d/%m/%Y'),500),
        ("B",datetime.strptime("03/01/2020",'%d/%m/%Y'),100),
        ("B",datetime.strptime("05/04/2020",'%d/%m/%Y'),200),
        ("B",datetime.strptime("29/04/2020",'%d/%m/%Y'),600),
        ("B",datetime.strptime("01/05/2020",'%d/%m/%Y'),1600)]

schema = T.StructType([T.StructField("userId", T.StringType(), True),
                       T.StructField("date", T.DateType(), True),
                       T.StructField("balance", T.IntegerType(), True)
                      ])
 
sdf_prueba = spark.createDataFrame(data=data,schema=schema)
sdf_prueba.printSchema()
sdf_prueba.orderBy(F.col('userId').asc(),F.col('date').asc()).show(truncate=False)

Using RANGE BETWEEN in Spark SQL. Casting date to a timestamp lets the frame be defined as a time interval, so the 30-day lookback is computed correctly even when the rows are unevenly spaced:

sdf_prueba.createOrReplaceTempView("table1")

spark.sql(
    """SELECT *, mean(balance) OVER (
        PARTITION BY userid 
        ORDER BY CAST(date AS timestamp)  
        RANGE BETWEEN INTERVAL 30 DAYS PRECEDING AND CURRENT ROW
     ) AS mean FROM table1""").show()


+------+----------+-------+-----+
|userId|      date|balance| mean|
+------+----------+-------+-----+
|     A|2020-06-09|    100|100.0|
|     A|2020-07-03|    200|150.0|
|     A|2020-08-05|    600|600.0|
|     A|2020-08-30|   1000|800.0|
|     A|2020-09-15|    500|750.0|
|     B|2020-01-03|    100|100.0|
|     B|2020-04-05|    200|200.0|
|     B|2020-04-29|    600|400.0|
|     B|2020-05-01|   1600|800.0|
+------+----------+-------+-----+

The same thing with the DataFrame API, converting the days to unix seconds so they can serve as rangeBetween bounds:

from pyspark.sql import Window
from pyspark.sql.functions import col, mean

# 30 days expressed in seconds, matching the unix-second values in the ORDER BY
w = Window.partitionBy("userId").orderBy(col("date").cast("timestamp").cast("long")).rangeBetween(-30 * 86400, 0)

sdf_prueba.select(col("*"), mean("balance").over(w).alias("mean")).show()

+------+----------+-------+-----+
|userId|      date|balance| mean|
+------+----------+-------+-----+
|     A|2020-06-09|    100|100.0|
|     A|2020-07-03|    200|150.0|
|     A|2020-08-05|    600|600.0|
|     A|2020-08-30|   1000|800.0|
|     A|2020-09-15|    500|750.0|
|     B|2020-01-03|    100|100.0|
|     B|2020-04-05|    200|200.0|
|     B|2020-04-29|    600|400.0|
|     B|2020-05-01|   1600|800.0|
+------+----------+-------+-----+
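
If this pattern gets reused, the bound reads better when derived from a small helper instead of a hard-coded constant. A minimal sketch of the same window (the days helper and the mean_last_30days_balance column name are illustrative, not part of the answers above):

from pyspark.sql import Window
from pyspark.sql import functions as F

def days(i):
    # rangeBetween bounds use the units of the ORDER BY column, here unix seconds
    return i * 86400

w = (Window.partitionBy("userId")
     .orderBy(F.col("date").cast("timestamp").cast("long"))
     .rangeBetween(-days(30), 0))

sdf_prueba.withColumn("mean_last_30days_balance", F.mean("balance").over(w)) \
          .orderBy("userId", "date").show()

Using withColumn keeps every original row and column, which matches the requirement of preserving the original DataFrame's row count.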