在 scala spark 中分组并保存重叠列的最大值

Group by and save the max value with overlapping columns in scala spark

我有这样的数据:

id,start,expiration,customerid,content
1,13494,17358,0001,whateveriwanthere
2,14830,28432,0001,somethingelsewoo
3,11943,19435,0001,yes
4,39271,40231,0002,makingfakedata
5,01321,02143,0002,morefakedata

在上面的数据中,我想按 customerid 分组以重叠 startexpiration(本质上只是合并间隔)。我通过按客户 ID 分组,然后聚合 first("start")max("expiration").

成功地做到了这一点
df.groupBy("customerid").agg(first("start"), max("expiration"))

但是,这会完全删除 id。我想保存最大过期行的 id。例如,我希望我的输出看起来像这样:

id,start,expiration,customerid
2,11934,28432,0001
4,39271,40231,0002
5,01321,02143,0002

我不确定如何为具有最长过期时间的行添加 id 列。

您可以结合使用累积条件和和 lag 函数来定义标记重叠行的 group 列。然后,简单地按 customerid + group 分组,得到最小值 start 和最大值 expiration。要获得与最大到期日期关联的 id 值,您可以将此技巧与结构排序一起使用:

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy("customerid").orderBy("start")

val result = df.withColumn(
    "group",
    sum(
      when(
        col("start").between(lag("start", 1).over(w), lag("expiration", 1).over(w)),
        0
      ).otherwise(1)
    ).over(w)
).groupBy("customerid", "group").agg(
    min(col("start")).as("start"),
    max(struct(col("expiration"), col("id"))).as("max")
).select("max.id", "customerid", "start", "max.expiration")

result.show
//+---+----------+-----+----------+
//| id|customerid|start|expiration|
//+---+----------+-----+----------+
//|  5|      0002|01321|     02143|
//|  4|      0002|39271|     40231|
//|  2|      0001|11943|     28432|
//+---+----------+-----+----------+