Pyspark:滚动中的聚合模式(最频繁)值 window

Pyspark: aggregate mode (most frequent) value in a rolling window

我有一个数据框,如下所示。我想按 device 分组并在每个组中按 start_time 排序。然后,对于组中的每一行,从它之前的 3 行(包括它自己)的 window 中获取最常出现的站点。

columns = ['device', 'start_time', 'station']
data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"), ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]


test_df = spark.createDataFrame(data).toDF(*columns)
rolling_w = Window.partitionBy('device').orderBy('start_time').rowsBetween(-2, 0)

期望的输出:

+------+----------+---------+--------------------+
|device|start_time|  station|rolling_mode_station|
+------+----------+---------+--------------------+
|Python|         1|station_1|           station_1|
|Python|         2|station_2|           station_2|
|Python|         3|station_1|           station_1|
|Python|         4|station_2|           station_2|
|Python|         5|station_2|           station_2|
|Python|         6|     null|           station_2|
+------+----------+---------+--------------------+

由于 Pyspark 没有 mode() 函数,我知道如何在静态 groupby 中获取最频繁的值,如图 ,但我不知道如何使其适应滚动 window.

您可以使用 collect_list 函数使用定义的 window 从最后 3 行获取站点,然后为每个结果数组计算最频繁的元素。

要获取数组中出现频率最高的元素,您可以将其展开,然后按链接进行分组并计数 post 您已经看到或使用了一些像这样的 UDF:

import pyspark.sql.functions as F

test_df.withColumn(
    "rolling_mode_station",
    F.collect_list("station").over(rolling_w)
).withColumn(
    "rolling_mode_station",
    F.udf(lambda x: max(set(x), key=x.count))(F.col("rolling_mode_station"))
).show()

#+------+----------+---------+--------------------+
#|device|start_time|  station|rolling_mode_station|
#+------+----------+---------+--------------------+
#|Python|         1|station_1|           station_1|
#|Python|         2|station_2|           station_1|
#|Python|         3|station_1|           station_1|
#|Python|         4|station_2|           station_2|
#|Python|         5|station_2|           station_2|
#|Python|         6|     null|           station_2|
#+------+----------+---------+--------------------+