Pyspark Dataframe 的 GroupBy 替代方案?

Alternative to GroupBy for Pyspark Dataframe?

我有这样的数据集:

timestamp     vars 
2             [1,2]
2             [1,2]
3             [1,2,3]
3             [1,2]

我想要这样的数据框。基本上,上述数据框中的每个值都是一个索引,该值的频率是该索引处的值。此计算针对每个唯一时间戳完成。

timestamp     vars 
2             [0, 2, 2]
3             [0,2,2,1]

现在,我按时间戳和 aggregrating/flattening vars 进行分组(以获得类似于(时间戳 2 的 1,2,1,2 或时间戳的 1,2,3,1,2 3) 然后我有一个使用 collections.Counter 的 udf 来获取键->值字典。然后我将这个字典转换成我想要的格式。

groupBy/agg 可以任意大(数组大小可以达到数百万),这似乎是 Window 函数的一个很好的用例,但我不确定如何放置它一起。

认为还值得一提的是,我尝试过重新分区、转换为 RDD 并使用 groupByKey。两者在大型数据集上都非常慢(>24 小时)。

编辑: 正如评论中所讨论的,原始方法的问题可能来自 count 使用触发不必要数据扫描的过滤器或聚合函数。下面我们在创建最终数组列之前分解数组并进行聚合(计数):

from pyspark.sql.functions import collect_list, struct  

df = spark.createDataFrame([(2,[1,2]), (2,[1,2]), (3,[1,2,3]), (3,[1,2])],['timestamp', 'vars'])

df.selectExpr("timestamp", "explode(vars) as var") \
    .groupby('timestamp','var') \
    .count() \
    .groupby("timestamp") \
    .agg(collect_list(struct("var","count")).alias("data")) \
    .selectExpr(
        "timestamp",
        "transform(data, x -> x.var) as indices",
        "transform(data, x -> x.count) as values"
    ).selectExpr(
        "timestamp",
        "transform(sequence(0, array_max(indices)), i -> IFNULL(values[array_position(indices,i)-1],0)) as new_vars"
    ).show(truncate=False)
+---------+------------+
|timestamp|new_vars    |
+---------+------------+
|3        |[0, 2, 2, 1]|
|2        |[0, 2, 2]   |
+---------+------------+

其中:

(1) 我们分解数组并对每个 timestamp + var

执行 count()

(2) groupby timestamp 并创建包含两个字段 varcount

的结构数组

(3) 将struct数组转换成两个数组:indices和values(类似于我们定义的SparseVector)

(4) 变换序列sequence(0, array_max(indices)),对于序列中的每个i,用array_positionindices数组中找到i的索引,然后检索values 数组中相同位置的值,见下文:

IFNULL(values[array_position(indices,i)-1],0)

注意函数array_position使用从1开始的索引而数组索引是从0开始的,因此我们在上面的表达式中有一个-1 .

旧方法:

(1) 使用变换 + filter/size

from pyspark.sql.functions import flatten, collect_list

df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
  .selectExpr(
    "timestamp", 
    "transform(sequence(0, array_max(data)), x -> size(filter(data, y -> y = x))) as vars"
  ).show(truncate=False)
+---------+------------+
|timestamp|vars        |
+---------+------------+
|3        |[0, 2, 2, 1]|
|2        |[0, 2, 2]   |
+---------+------------+

(2) 使用aggregate函数:

df.groupby('timestamp').agg(flatten(collect_list('vars')).alias('data')) \
   .selectExpr("timestamp", """ 

     aggregate(   
       data,         
       /* use an array as zero_value, size = array_max(data))+1 and all values are zero */
       array_repeat(0, int(array_max(data))+1),       
       /* increment the ith value of the Array by 1 if i == y */
       (acc, y) -> transform(acc, (x,i) -> IF(i=y, x+1, x))       
     ) as vars   

""").show(truncate=False)