根据值对数据集行进行分组

group the dataset rows based on the value

我需要基于某些条件的数据集的组行。我的输入数据集就像 df1 :

+------------+----------+----------------------------+------------------+-------+
|     col_1  |     col_2|col_3                       |            tp    |  range|
+------------+----------+----------------------------+------------------+-------+
|MP          |W         |                           X|10                |]0,3]  |
|MP          |W         |                           X|20                |]12,30]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                           X|18                |]0,3]  |
|MP          |W         |                           X|30                |]0,3]  |
|MP          |W         |                           X|50                |]12,30]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                           X|60                |]12,30]|
|MP          |W         |                           X|50                |]12,30]|
|MP          |W         |                           X|70                |]12,30]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                           X|90                |]12,30]|
|MP          |W         |                           X|18                |]36,48]|
|MP          |W         |                           X|18                |]36,48]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                           X|180               |]12,30]|
|MP          |W         |                           X|18                |]36,48]|
|MP          |W         |                           X|18                |]12,30]|
|MP          |W         |                         S2E|19                |]24,36]|
|MP          |W         |                         S2E|40                |]24,36]|
+------------+----------+----------------------------+------------------+-------+

我想做的是:

  1. 按范围(最后一列)对 df1 的行进行分组 [df = df1.select("*").groupby("col_1", "col_2", "col_3", "tp", "范围"]
  2. 对于相同范围内的行,创建子组,其中同一子组的 2 yield (column name = tp) 之间的比率小于 2 [即 tp(i-1)/tp(i) < 2 或tp(i-2)/tp(i) < 2]

在范围 ]12,30] 的输出中,我将得到类似的内容:

+------------+----------+----------------------------+------------------+-------+------------+
|     col_1  |     col_2|col_3                       |            tp    |  range|  subgroup  |
+------------+----------+----------------------------+------------------+-------+------------+
|MP          |W         |                           X|20                |]12,30]|subgroup_1  |
|MP          |W         |                           X|18                |]12,30]|subgroup_1  |
|MP          |W         |                           X|50                |]12,30]|subgroup_2  |
|MP          |W         |                           X|18                |]12,30]|subgroup_1  |
|MP          |W         |                           X|60                |]12,30]|subgroup_2  |
|MP          |W         |                           X|50                |]12,30]|subgroup_2  |
|MP          |W         |                           X|70                |]12,30]|subgroup_2  |
|MP          |W         |                           X|90                |]12,30]|subgroup_2  |
|MP          |W         |                           X|180               |]12,30]|subgroup_3  |
+------------+----------+----------------------------+------------------+-------+------------+

有人有解决办法吗?我在 Spark Java.

工作

首先,col_1col_2col_3range 列无关紧要。它们可以通过 group 列抽象出来。

想法是使用 window 函数按 tp 值对每个 window 中的行进行排序,然后:

  1. 为每一行创建一个行号,稍后将用作子组 ID。
  2. 计算每一行与其前一行之间的比率
  3. 如果比值大于等于2,则使用当前行的行号作为子组id;否则,继承上一行的子组 ID。

scala 中的代码,但应该演示这个想法:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val data = Seq(
  (1, 20),
  (1, 18),
  (1, 50),
  (1, 18),
  (1, 60),
  (1, 50),
  (1, 70),
  (1, 90),
  (1, 180),
  (1, 360)
) toDF ("group", "tp")

val windowSpec = Window.partitionBy($"group").orderBy($"tp")
val df = data
  .withColumn("lag_tp", lag($"tp", 1, 0).over(windowSpec))
  .withColumn("row_num", row_number.over(windowSpec))
  .withColumn("reci_yield", $"lag_tp" / $"tp")
  .withColumn("yield_ge_2", $"reci_yield" <= 0.5)
  .withColumn("subGroup", 
                // When yield >= 2 detected, get the current row number as subGroup id
                when($"yield_ge_2" === true, $"row_num") 
                .otherwise(
                  // otherwise, get the last non-null subGroup id.
                   last(
                     when($"yield_ge_2"===true, $"row_num"), 
                     ignoreNulls = true
                   ).over(windowSpec)
                )
             )
  // drop intermediate columns
  .drop("row_num", "lag_tp", "reci_yield", "yield_ge_2")

df.show(false)

输出:

+-----+---+--------+
|group|tp |subGroup|
+-----+---+--------+
|1    |18 |1       |
|1    |18 |1       |
|1    |20 |1       |
|1    |50 |4       |
|1    |50 |4       |
|1    |60 |4       |
|1    |70 |4       |
|1    |90 |4       |
|1    |180|9       |
|1    |360|10      |
+-----+---+--------+

来源: