Pyspark:滚动中的聚合模式(最频繁)值 window
Pyspark: aggregate mode (most frequent) value in a rolling window
我有一个数据框,如下所示。我想按 device
分组并在每个组中按 start_time
排序。然后,对于组中的每一行,从它之前的 3 行(包括它自己)的 window 中获取最常出现的站点。
columns = ['device', 'start_time', 'station']
data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"), ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]
test_df = spark.createDataFrame(data).toDF(*columns)
rolling_w = Window.partitionBy('device').orderBy('start_time').rowsBetween(-2, 0)
期望的输出:
+------+----------+---------+--------------------+
|device|start_time| station|rolling_mode_station|
+------+----------+---------+--------------------+
|Python| 1|station_1| station_1|
|Python| 2|station_2| station_2|
|Python| 3|station_1| station_1|
|Python| 4|station_2| station_2|
|Python| 5|station_2| station_2|
|Python| 6| null| station_2|
+------+----------+---------+--------------------+
由于 Pyspark 没有 mode()
函数,我知道如何在静态 groupby
中获取最频繁的值,如图 ,但我不知道如何使其适应滚动 window.
您可以使用 collect_list
函数使用定义的 window 从最后 3 行获取站点,然后为每个结果数组计算最频繁的元素。
要获取数组中出现频率最高的元素,您可以将其展开,然后按链接进行分组并计数 post 您已经看到或使用了一些像这样的 UDF:
import pyspark.sql.functions as F
test_df.withColumn(
"rolling_mode_station",
F.collect_list("station").over(rolling_w)
).withColumn(
"rolling_mode_station",
F.udf(lambda x: max(set(x), key=x.count))(F.col("rolling_mode_station"))
).show()
#+------+----------+---------+--------------------+
#|device|start_time| station|rolling_mode_station|
#+------+----------+---------+--------------------+
#|Python| 1|station_1| station_1|
#|Python| 2|station_2| station_1|
#|Python| 3|station_1| station_1|
#|Python| 4|station_2| station_2|
#|Python| 5|station_2| station_2|
#|Python| 6| null| station_2|
#+------+----------+---------+--------------------+
我有一个数据框,如下所示。我想按 device
分组并在每个组中按 start_time
排序。然后,对于组中的每一行,从它之前的 3 行(包括它自己)的 window 中获取最常出现的站点。
columns = ['device', 'start_time', 'station']
data = [("Python", 1, "station_1"), ("Python", 2, "station_2"), ("Python", 3, "station_1"), ("Python", 4, "station_2"), ("Python", 5, "station_2"), ("Python", 6, None)]
test_df = spark.createDataFrame(data).toDF(*columns)
rolling_w = Window.partitionBy('device').orderBy('start_time').rowsBetween(-2, 0)
期望的输出:
+------+----------+---------+--------------------+
|device|start_time| station|rolling_mode_station|
+------+----------+---------+--------------------+
|Python| 1|station_1| station_1|
|Python| 2|station_2| station_2|
|Python| 3|station_1| station_1|
|Python| 4|station_2| station_2|
|Python| 5|station_2| station_2|
|Python| 6| null| station_2|
+------+----------+---------+--------------------+
由于 Pyspark 没有 mode()
函数,我知道如何在静态 groupby
中获取最频繁的值,如图
您可以使用 collect_list
函数使用定义的 window 从最后 3 行获取站点,然后为每个结果数组计算最频繁的元素。
要获取数组中出现频率最高的元素,您可以将其展开,然后按链接进行分组并计数 post 您已经看到或使用了一些像这样的 UDF:
import pyspark.sql.functions as F
test_df.withColumn(
"rolling_mode_station",
F.collect_list("station").over(rolling_w)
).withColumn(
"rolling_mode_station",
F.udf(lambda x: max(set(x), key=x.count))(F.col("rolling_mode_station"))
).show()
#+------+----------+---------+--------------------+
#|device|start_time| station|rolling_mode_station|
#+------+----------+---------+--------------------+
#|Python| 1|station_1| station_1|
#|Python| 2|station_2| station_1|
#|Python| 3|station_1| station_1|
#|Python| 4|station_2| station_2|
#|Python| 5|station_2| station_2|
#|Python| 6| null| station_2|
#+------+----------+---------+--------------------+