使用 spark scala 根据条件对列值求和

Sum a column values based on a condition using spark scala

我有一个这样的数据框:

JoiKey period Age Amount
Jk1 2022-02 2 200
Jk1 2022-02 3 450
Jk2 2022-03 5 500
Jk3 2022-03 0 200
Jk2 2022-02 8 300
Jk3 2022-03 9 200
Jk2 2022-04 1 100

有什么方法可以使用 spark scala 根据条件创建两个新列。

列金额(年龄 <= 3)>> 年龄 > 3 的金额总和列金额(年龄 > 3)>> 年龄 <= 3 的金额总和

需要按 Joinkey 和 Period 分组并删除“Age”和“Amount”列

期望的输出将是:

JoiKey period Amount (Age <= 3) Amount (Age > 3)
Jk1 2022-02 650 0
Jk2 2022-03 0 500
Jk2 2022-02 0 300
Jk2 2022-04 100 0
Jk3 2022-03 200 200

当然可以,但是您希望您的数据如何?如果您希望输出类似于:

Age   Amount    A    B
 2     200     500  1450
 3     450     500  1450
 5     500     500  1450
 0     200     500  1450
 8     300     500  1450
 9     200     500  1450
 1     100     500  1450

那么这是一个 windowed 聚合函数(windowing over sum)。 windowing 函数用于放置所有行的聚合值(在本例中)。

df
  .withColumn(
    "A",
    sum(when(col("Age") lt 3, col("Amount")).otherwise(lit(0)))
  .over()
)
.withColumn(
    "B",
    sum(when(col("Age") >= 3, col("Amount")).otherwise(lit(0)))
  .over()
)

请注意,使用不带分区的 over window 函数根本没有性能,请使用分区。 这是输出:

22/04/25 23:54:01 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+---+------+---+----+
|Age|Amount|  A|   B|
+---+------+---+----+
|  2|   200|500|1450|
|  3|   450|500|1450|
|  5|   500|500|1450|
|  0|   200|500|1450|
|  8|   300|500|1450|
|  9|   200|500|1450|
|  1|   100|500|1450|
+---+------+---+----+

更新: 所以在你更新问题后,我建议你这样做:

df
  .groupBy(
    col("JoinKey"), col("period"), expr("Age < 3").as("under3")
  ).agg(sum(col("Amount")) as "grouped_age_sum")
  .withColumn("A", sum(when(col("under3") === true, col("grouped_age_sum")).otherwise(lit(0)))
  .over()
  )
  .withColumn("B", sum(when(col("under3") === false, col("grouped_age_sum")).otherwise(lit(0)))
  .over()
  ).drop("grouped_age_sum", "under3")
  .groupBy(col("JoinKey"), col("period")).min()
  .withColumnRenamed("min(A)", "A")
  .withColumnRenamed("min(B)", "B")
  .show

请注意,关于分区的同样事情仍然存在,我有一些样本数据并且并不真正需要性能(它还会向解决方案添加一些依赖于逻辑的样板),但你应该这样做,这里是输出:

22/04/26 22:36:48 WARN WindowExec: No Partition Defined for Window operation! Moving all data to a single partition, this can cause serious performance degradation.
+-------+-------+---+----+
|JoinKey| period|  A|   B|
+-------+-------+---+----+
|    JK1|2022-02|500|1450|
|    JK2|2022-03|500|1450|
|    JK3|2022-03|500|1450|
|    JK2|2022-02|500|1450|
|    JK2|2022-04|500|1450|
+-------+-------+---+----+

更新#2:

在您提供了更清晰的解释之后:所以您只需要使用 2 个简单的聚合函数进行分组:

df
    .groupBy(col("JoinKey"), col("period"))
    .agg(
      sum(when(col("Age") lt 4, col("Amount")).otherwise(lit(0))).as("Amount (Age <= 3)"),
      sum(when(col("Age") gt 3, col("Amount")).otherwise(lit(0))).as("Amount (Age > 3)")
    )

输出:

+-------+-------+-----------------+----------------+
|JoinKey| period|Amount (Age <= 3)|Amount (Age > 3)|
+-------+-------+-----------------+----------------+
|    JK1|2022-02|              650|               0|
|    JK2|2022-03|                0|             500|
|    JK3|2022-03|              200|             200|
|    JK2|2022-02|                0|             300|
|    JK2|2022-04|              100|               0|
+-------+-------+-----------------+----------------+