从 Spark DataFrame 中的按列运行创建唯一的分组键

Creating a unique grouping key from column-wise runs in a Spark DataFrame

我有类似的东西,其中 spark 是我的 sparkContext。我在我的 sparkContext 中导入了 implicits._,所以我可以使用 $ 语法:

val df = spark.createDataFrame(Seq(("a", 0L), ("b", 1L), ("c", 1L), ("d", 1L), ("e", 0L), ("f", 1L)))
              .toDF("id", "flag")
              .withColumn("index", monotonically_increasing_id)
              .withColumn("run_key", when($"flag" === 1, $"index").otherwise(0))

df.show

df: org.apache.spark.sql.DataFrame = [id: string, flag: bigint ... 2 more fields]
+---+----+-----+-------+
| id|flag|index|run_key|
+---+----+-----+-------+
|  a|   0|    0|      0|
|  b|   1|    1|      1|
|  c|   1|    2|      2|
|  d|   1|    3|      3|
|  e|   0|    4|      0|
|  f|   1|    5|      5|
+---+----+-----+-------+

我想为 run_key 的每个非零块创建另一个具有唯一分组键的列,等同于此:

+---+----+-----+-------+---+
| id|flag|index|run_key|key|
+---+----+-----+-------+---|
|  a|   0|    0|      0|  0|
|  b|   1|    1|      1|  1|
|  c|   1|    2|      2|  1|
|  d|   1|    3|      3|  1|
|  e|   0|    4|      0|  0|
|  f|   1|    5|      5|  2|
+---+----+-----+-------+---+

它可以是每个 运行 中的第一个值、每个 运行 的平均值,或者其他一些值——只要保证它是唯一的,这并不重要之后我可以对其进行分组,以比较各组之间的其他值。

编辑:顺便说一句,我不需要保留 flag0 的行。

您可以用最大索引标记 "run",其中 flag 小于相关行的索引 0

类似于:

flags = df.filter($"flag" === 0)
  .select("index")
  .withColumnRenamed("index", "flagIndex")
indices = df.select("index").join(flags, df.index > flags.flagIndex)
  .groupBy($"index")
  .agg(max($"index$).as("groupKey"))
dfWithGroups = df.join(indices, Seq("index"))

一种方法是 1) 使用 $"flag" 中的 Window 函数 lag() 创建一个列 $"lag1",2) 创建另一个列 $"switched" 在 $"flag" 被切换的行中使用 $"index" 值,最后 3) 创建通过 [=12 从最后 non-null 行复制 $"switched" 的列=] 和 rowsBetween().

请注意,此解决方案使用 Window 函数而不进行分区,因此可能不适用于大型数据集。

val df = Seq(
  ("a", 0L), ("b", 1L), ("c", 1L), ("d", 1L), ("e", 0L), ("f", 1L),
  ("g", 1L), ("h", 0L), ("i", 0L), ("j", 1L), ("k", 1L), ("l", 1L)
).toDF("id", "flag").
  withColumn("index", monotonically_increasing_id).
  withColumn("run_key", when($"flag" === 1, $"index").otherwise(0))

import org.apache.spark.sql.expressions.Window

df.withColumn( "lag1", lag("flag", 1, -1).over(Window.orderBy("index")) ).
  withColumn( "switched", when($"flag" =!= $"lag1", $"index") ).
  withColumn( "key", last("switched", ignoreNulls = true).over(
    Window.orderBy("index").rowsBetween(Window.unboundedPreceding, 0)
  ) )

// +---+----+-----+-------+----+--------+---+
// | id|flag|index|run_key|lag1|switched|key|
// +---+----+-----+-------+----+--------+---+
// |  a|   0|    0|      0|  -1|       0|  0|
// |  b|   1|    1|      1|   0|       1|  1|
// |  c|   1|    2|      2|   1|    null|  1|
// |  d|   1|    3|      3|   1|    null|  1|
// |  e|   0|    4|      0|   1|       4|  4|
// |  f|   1|    5|      5|   0|       5|  5|
// |  g|   1|    6|      6|   1|    null|  5|
// |  h|   0|    7|      0|   1|       7|  7|
// |  i|   0|    8|      0|   0|    null|  7|
// |  j|   1|    9|      9|   0|       9|  9|
// |  k|   1|   10|     10|   1|    null|  9|
// |  l|   1|   11|     11|   1|    null|  9|
// +---+----+-----+-------+----+--------+---+