Pyspark Dataframe 的 GroupBy 替代方案?
Alternative to GroupBy for Pyspark Dataframe?
我有这样的数据集:
timestamp vars
2 [1,2]
2 [1,2]
3 [1,2,3]
3 [1,2]
我想要这样的数据框。基本上,上述数据框中的每个值都是一个索引,该值的频率是该索引处的值。此计算针对每个唯一时间戳完成。
timestamp vars
2 [0, 2, 2]
3 [0,2,2,1]
现在,我按时间戳和 aggregrating/flattening vars 进行分组(以获得类似于(时间戳 2 的 1,2,1,2 或时间戳的 1,2,3,1,2 3) 然后我有一个使用 collections.Counter 的 udf 来获取键->值字典。然后我将这个字典转换成我想要的格式。
groupBy/agg 可以任意大(数组大小可以达到数百万),这似乎是 Window 函数的一个很好的用例,但我不确定如何放置它一起。
认为还值得一提的是,我尝试过重新分区、转换为 RDD 并使用 groupByKey。两者在大型数据集上都非常慢(>24 小时)。
编辑: 正如评论中所讨论的,原始方法的问题可能来自 count
使用触发不必要数据扫描的过滤器或聚合函数。下面我们在创建最终数组列之前分解数组并进行聚合(计数):
from pyspark.sql.functions import collect_list, struct
df = spark.createDataFrame([(2,[1,2]), (2,[1,2]), (3,[1,2,3]), (3,[1,2])],['timestamp', 'vars'])
df.selectExpr("timestamp", "explode(vars) as var") \
.groupby('timestamp','var') \
.count() \
.groupby("timestamp") \
.agg(collect_list(struct("var","count")).alias("data")) \
.selectExpr(
"timestamp",
"transform(data, x -> x.var) as indices",
"transform(data, x -> x.count) as values"
).selectExpr(
"timestamp",
"transform(sequence(0, array_max(indices)), i -> IFNULL(values[array_position(indices,i)-1],0)) as new_vars"
).show(truncate=False)
+---------+------------+
|timestamp|new_vars |
+---------+------------+
|3 |[0, 2, 2, 1]|
|2 |[0, 2, 2] |
+---------+------------+
其中:
(1) 我们分解数组并对每个 timestamp
+ var
执行 count()
(2) groupby timestamp
并创建包含两个字段 var
和 count
的结构数组
(3) 将struct数组转换成两个数组:indices和values(类似于我们定义的SparseVector)
(4) 变换序列sequence(0, array_max(indices))
,对于序列中的每个i,用array_position在indices
数组中找到i
的索引,然后检索values
数组中相同位置的值,见下文:
IFNULL(values[array_position(indices,i)-1],0)
注意函数array_position使用从1开始的索引而数组索引是从0开始的,因此我们在上面的表达式中有一个-1
.
旧方法:
(1) 使用变换 + filter/size
from pyspark.sql.functions import flatten, collect_list
df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
.selectExpr(
"timestamp",
"transform(sequence(0, array_max(data)), x -> size(filter(data, y -> y = x))) as vars"
).show(truncate=False)
+---------+------------+
|timestamp|vars |
+---------+------------+
|3 |[0, 2, 2, 1]|
|2 |[0, 2, 2] |
+---------+------------+
(2) 使用aggregate函数:
df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
.selectExpr("timestamp", """
aggregate(
data,
/* use an array as zero_value, size = array_max(data))+1 and all values are zero */
array_repeat(0, int(array_max(data))+1),
/* increment the ith value of the Array by 1 if i == y */
(acc, y) -> transform(acc, (x,i) -> IF(i=y, x+1, x))
) as vars
""").show(truncate=False)
我有这样的数据集:
timestamp vars
2 [1,2]
2 [1,2]
3 [1,2,3]
3 [1,2]
我想要这样的数据框。基本上,上述数据框中的每个值都是一个索引,该值的频率是该索引处的值。此计算针对每个唯一时间戳完成。
timestamp vars
2 [0, 2, 2]
3 [0,2,2,1]
现在,我按时间戳和 aggregrating/flattening vars 进行分组(以获得类似于(时间戳 2 的 1,2,1,2 或时间戳的 1,2,3,1,2 3) 然后我有一个使用 collections.Counter 的 udf 来获取键->值字典。然后我将这个字典转换成我想要的格式。
groupBy/agg 可以任意大(数组大小可以达到数百万),这似乎是 Window 函数的一个很好的用例,但我不确定如何放置它一起。
认为还值得一提的是,我尝试过重新分区、转换为 RDD 并使用 groupByKey。两者在大型数据集上都非常慢(>24 小时)。
编辑: 正如评论中所讨论的,原始方法的问题可能来自 count
使用触发不必要数据扫描的过滤器或聚合函数。下面我们在创建最终数组列之前分解数组并进行聚合(计数):
from pyspark.sql.functions import collect_list, struct
df = spark.createDataFrame([(2,[1,2]), (2,[1,2]), (3,[1,2,3]), (3,[1,2])],['timestamp', 'vars'])
df.selectExpr("timestamp", "explode(vars) as var") \
.groupby('timestamp','var') \
.count() \
.groupby("timestamp") \
.agg(collect_list(struct("var","count")).alias("data")) \
.selectExpr(
"timestamp",
"transform(data, x -> x.var) as indices",
"transform(data, x -> x.count) as values"
).selectExpr(
"timestamp",
"transform(sequence(0, array_max(indices)), i -> IFNULL(values[array_position(indices,i)-1],0)) as new_vars"
).show(truncate=False)
+---------+------------+
|timestamp|new_vars |
+---------+------------+
|3 |[0, 2, 2, 1]|
|2 |[0, 2, 2] |
+---------+------------+
其中:
(1) 我们分解数组并对每个 timestamp
+ var
(2) groupby timestamp
并创建包含两个字段 var
和 count
(3) 将struct数组转换成两个数组:indices和values(类似于我们定义的SparseVector)
(4) 变换序列sequence(0, array_max(indices))
,对于序列中的每个i,用array_position在indices
数组中找到i
的索引,然后检索values
数组中相同位置的值,见下文:
IFNULL(values[array_position(indices,i)-1],0)
注意函数array_position使用从1开始的索引而数组索引是从0开始的,因此我们在上面的表达式中有一个-1
.
旧方法:
(1) 使用变换 + filter/size
from pyspark.sql.functions import flatten, collect_list
df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
.selectExpr(
"timestamp",
"transform(sequence(0, array_max(data)), x -> size(filter(data, y -> y = x))) as vars"
).show(truncate=False)
+---------+------------+
|timestamp|vars |
+---------+------------+
|3 |[0, 2, 2, 1]|
|2 |[0, 2, 2] |
+---------+------------+
(2) 使用aggregate函数:
df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
.selectExpr("timestamp", """
aggregate(
data,
/* use an array as zero_value, size = array_max(data))+1 and all values are zero */
array_repeat(0, int(array_max(data))+1),
/* increment the ith value of the Array by 1 if i == y */
(acc, y) -> transform(acc, (x,i) -> IF(i=y, x+1, x))
) as vars
""").show(truncate=False)