Spark:如何将行分组为固定大小的数组?

Spark: how to group rows into a fixed size array?

我有一个如下所示的数据集:

+---+
|col|
+---+
|  a|
|  b|
|  c|
|  d|
|  e|
|  f|
|  g|
+---+

我想重新格式化此数据集,以便将行聚合成固定长度的数组,如下所示:

+------+
|   col|
+------+
|[a, b]|
|[c, d]|
|[e, f]|
|   [g]|
+------+

我试过这个:

spark.sql("select collect_list(col) from (select col, row_number() over (order by col) row_number from dataset) group by floor(row_number/2)")

但问题是我的实际数据集太大,无法在 row_number()

的单个分区中处理

如您希望分发此文件,需要执行几个步骤。

以防万一,您希望 运行 代码,我从这里开始:

var df = List(
  "a", "b", "c", "d", "e", "f", "g"
).toDF("col")
val desiredArrayLength = 2

首先,将您的数据框分成一个小的,您可以在单个节点上处理,一个大的,其行数是所需数组大小的倍数(在您的示例中,这是 2)

val nRowsPrune = 1 //number of rows to prune such that remaining dataframe has number of
                   // rows is multiples of the desired length of array
val dfPrune = df.sort(desc("col")).limit(nRowsPrune)
df = df.join(dfPrune,Seq("col"),"left_anti") //separate small from large dataframe

通过构造,可以在小dataframe上套用原代码,

val groupedPruneDf = dfPrune//.withColumn("g",floor((lit(-1)+row_number().over(w))/lit(desiredArrayLength ))) //added -1 as row-number starts from 1
                            //.groupBy("g")
                            .agg( collect_list("col").alias("col"))
                            .select("col")

现在,我们需要找到一种方法来处理剩余的大型数据帧。但是,现在我们确定 df 的行数是数组大小的倍数。 这是我们使用一个绝妙技巧的地方,即使用 repartitionByRange 重新分区。基本上,分区保证保留排序,并且在分区时每个分区将具有相同的大小。 您现在可以收集每个分区中的每个数组,

   val nRows = df.count()
   val maxNRowsPartition = desiredArrayLength //make sure its a multiple of desired array length
   val nPartitions = math.max(1,math.floor(nRows/maxNRowsPartition) ).toInt
   df = df.repartitionByRange(nPartitions, $"col".desc)
          .withColumn("partitionId",spark_partition_id())

    val w = Window.partitionBy($"partitionId").orderBy("col")
    val groupedDf = df
        .withColumn("g",  floor( (lit(-1)+row_number().over(w))/lit(desiredArrayLength ))) //added -1 as row-number starts from 1
        .groupBy("partitionId","g")
        .agg( collect_list("col").alias("col"))
        .select("col")

最后将这两个结果结合起来得到您正在寻找的结果,

val result = groupedDf.union(groupedPruneDf)
result.show(truncate=false)