如何根据spark scala中的条件进行累计和
How to do cumulative sum based on conditions in spark scala
我有以下数据,final_column
是我想要得到的确切输出。我正在尝试计算 flag
的累计总和,如果 flag
为 0,则想休息,然后将值设置为 0,如下数据
cola date flag final_column
a 2021-10-01 0 0
a 2021-10-02 1 1
a 2021-10-03 1 2
a 2021-10-04 0 0
a 2021-10-05 0 0
a 2021-10-06 0 0
a 2021-10-07 1 1
a 2021-10-08 1 2
a 2021-10-09 1 3
a 2021-10-10 0 0
b 2021-10-01 0 0
b 2021-10-02 1 1
b 2021-10-03 1 2
b 2021-10-04 0 0
b 2021-10-05 0 0
b 2021-10-06 1 1
b 2021-10-07 1 2
b 2021-10-08 1 3
b 2021-10-09 1 4
b 2021-10-10 0 0
我试过了
import org.apache.spark.sql.functions._
df.withColumn("final_column",expr("sum(flag) over(partition by cola order date asc)"))
我试过在 sum 函数中添加类似 case when flag = 0 then 0 else 1 end
的条件,但没有用。
您可以使用 flag
上的条件求和来定义列 group
,然后使用 row_number
和由 cola
和 [=12 分区的 Window =]给出你想要的结果:
import org.apache.spark.sql.expressions.Window
val result = df.withColumn(
"group",
sum(when(col("flag") === 0, 1).otherwise(0)).over(Window.partitionBy("cola").orderBy("date"))
).withColumn(
"final_column",
row_number().over(Window.partitionBy("cola", "group").orderBy("date")) - 1
).drop("group")
result.show
//+----+-----+----+------------+
//|cola| date|flag|final_column|
//+----+-----+----+------------+
//| b|44201| 0| 0|
//| b|44202| 1| 1|
//| b|44203| 1| 2|
//| b|44204| 0| 0|
//| b|44205| 0| 0|
//| b|44206| 1| 1|
//| b|44207| 1| 2|
//| b|44208| 1| 3|
//| b|44209| 1| 4|
//| b|44210| 0| 0|
//| a|44201| 0| 0|
//| a|44202| 1| 1|
//| a|44203| 1| 2|
//| a|44204| 0| 0|
//| a|44205| 0| 0|
//| a|44206| 0| 0|
//| a|44207| 1| 1|
//| a|44208| 1| 2|
//| a|44209| 1| 3|
//| a|44210| 0| 0|
//+----+-----+----+------------+
row_number() - 1
在这种情况下等同于 sum(col("flag"))
因为标志值总是 0 或 1。所以上面的 final_column
也可以写成:
.withColumn(
"final_column",
sum(col("flag")).over(Window.partitionBy("cola", "group").orderBy("date"))
)
我有以下数据,final_column
是我想要得到的确切输出。我正在尝试计算 flag
的累计总和,如果 flag
为 0,则想休息,然后将值设置为 0,如下数据
cola date flag final_column
a 2021-10-01 0 0
a 2021-10-02 1 1
a 2021-10-03 1 2
a 2021-10-04 0 0
a 2021-10-05 0 0
a 2021-10-06 0 0
a 2021-10-07 1 1
a 2021-10-08 1 2
a 2021-10-09 1 3
a 2021-10-10 0 0
b 2021-10-01 0 0
b 2021-10-02 1 1
b 2021-10-03 1 2
b 2021-10-04 0 0
b 2021-10-05 0 0
b 2021-10-06 1 1
b 2021-10-07 1 2
b 2021-10-08 1 3
b 2021-10-09 1 4
b 2021-10-10 0 0
我试过了
import org.apache.spark.sql.functions._
df.withColumn("final_column",expr("sum(flag) over(partition by cola order date asc)"))
我试过在 sum 函数中添加类似 case when flag = 0 then 0 else 1 end
的条件,但没有用。
您可以使用 flag
上的条件求和来定义列 group
,然后使用 row_number
和由 cola
和 [=12 分区的 Window =]给出你想要的结果:
import org.apache.spark.sql.expressions.Window
val result = df.withColumn(
"group",
sum(when(col("flag") === 0, 1).otherwise(0)).over(Window.partitionBy("cola").orderBy("date"))
).withColumn(
"final_column",
row_number().over(Window.partitionBy("cola", "group").orderBy("date")) - 1
).drop("group")
result.show
//+----+-----+----+------------+
//|cola| date|flag|final_column|
//+----+-----+----+------------+
//| b|44201| 0| 0|
//| b|44202| 1| 1|
//| b|44203| 1| 2|
//| b|44204| 0| 0|
//| b|44205| 0| 0|
//| b|44206| 1| 1|
//| b|44207| 1| 2|
//| b|44208| 1| 3|
//| b|44209| 1| 4|
//| b|44210| 0| 0|
//| a|44201| 0| 0|
//| a|44202| 1| 1|
//| a|44203| 1| 2|
//| a|44204| 0| 0|
//| a|44205| 0| 0|
//| a|44206| 0| 0|
//| a|44207| 1| 1|
//| a|44208| 1| 2|
//| a|44209| 1| 3|
//| a|44210| 0| 0|
//+----+-----+----+------------+
row_number() - 1
在这种情况下等同于 sum(col("flag"))
因为标志值总是 0 或 1。所以上面的 final_column
也可以写成:
.withColumn(
"final_column",
sum(col("flag")).over(Window.partitionBy("cola", "group").orderBy("date"))
)