如何根据spark scala中的条件进行累计和

How to do cumulative sum based on conditions in spark scala

我有以下数据,final_column 是我想要得到的确切输出。我正在尝试计算 flag 的累计总和,如果 flag 为 0,则想休息,然后将值设置为 0,如下数据

cola date       flag final_column
a   2021-10-01  0   0
a   2021-10-02  1   1
a   2021-10-03  1   2
a   2021-10-04  0   0
a   2021-10-05  0   0
a   2021-10-06  0   0
a   2021-10-07  1   1
a   2021-10-08  1   2
a   2021-10-09  1   3
a   2021-10-10  0   0
b   2021-10-01  0   0
b   2021-10-02  1   1
b   2021-10-03  1   2
b   2021-10-04  0   0
b   2021-10-05  0   0
b   2021-10-06  1   1
b   2021-10-07  1   2
b   2021-10-08  1   3
b   2021-10-09  1   4
b   2021-10-10  0   0

我试过了

import org.apache.spark.sql.functions._

df.withColumn("final_column",expr("sum(flag) over(partition by cola order date asc)"))

我试过在 sum 函数中添加类似 case when flag = 0 then 0 else 1 end 的条件,但没有用。

您可以使用 flag 上的条件求和来定义列 group,然后使用 row_number 和由 cola 和 [=12 分区的 Window =]给出你想要的结果:

import org.apache.spark.sql.expressions.Window

val result = df.withColumn(
    "group",
    sum(when(col("flag") === 0, 1).otherwise(0)).over(Window.partitionBy("cola").orderBy("date"))
).withColumn(
    "final_column",
    row_number().over(Window.partitionBy("cola", "group").orderBy("date")) - 1
).drop("group")

result.show

//+----+-----+----+------------+
//|cola| date|flag|final_column|
//+----+-----+----+------------+
//|   b|44201|   0|           0|
//|   b|44202|   1|           1|
//|   b|44203|   1|           2|
//|   b|44204|   0|           0|
//|   b|44205|   0|           0|
//|   b|44206|   1|           1|
//|   b|44207|   1|           2|
//|   b|44208|   1|           3|
//|   b|44209|   1|           4|
//|   b|44210|   0|           0|
//|   a|44201|   0|           0|
//|   a|44202|   1|           1|
//|   a|44203|   1|           2|
//|   a|44204|   0|           0|
//|   a|44205|   0|           0|
//|   a|44206|   0|           0|
//|   a|44207|   1|           1|
//|   a|44208|   1|           2|
//|   a|44209|   1|           3|
//|   a|44210|   0|           0|
//+----+-----+----+------------+

row_number() - 1 在这种情况下等同于 sum(col("flag")) 因为标志值总是 0 或 1。所以上面的 final_column 也可以写成:

.withColumn(
    "final_column",
    sum(col("flag")).over(Window.partitionBy("cola", "group").orderBy("date"))
)