Pyspark: How to code Complicated Dataframe algorithm problem (summing with condition)

My dataframe looks like the following:

from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
from pyspark.sql.functions import to_date

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("Trigger", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col1", IntegerType(), True),
                          StructField("col2", IntegerType(), True),
                          StructField("want", FloatType(), True)])
TEST_data = [('2020-08-01','T',0.0,3,5,0.5),('2020-08-02','T',0.0,-1,4,0.0),
             ('2020-08-03','T',0.0,-1,3,0.0),('2020-08-04','F',0.2,3,3,0.7),
             ('2020-08-05','T',0.3,1,4,0.9),('2020-08-06','F',0.2,-1,3,0.0),
             ('2020-08-07','T',0.2,-1,4,0.0),('2020-08-08','T',0.5,-1,5,0.0),
             ('2020-08-09','T',0.0,-1,5,0.0)]
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()
+----------+-------+-----+----+----+----+
|      date|Trigger|value|col1|col2|want|
+----------+-------+-----+----+----+----+
|2020-08-01|      T|  0.0|   3|   5| 0.5|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|
|2020-08-05|      T|  0.3|   1|   4| 0.9|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|
+----------+-------+-----+----+----+----+

date : nicely sorted in ascending order

Trigger : only T or F

value : any random decimal (float) value

col1 : represents a number of days, cannot be smaller than -1. **-1 <= col1 < infinity**

col2 : represents a number of days, cannot be negative. col2 >= 0

**Calculation logic**

If col1 == -1, then return 0; otherwise, when Trigger == T, the diagram below helps to understand the logic.

If we look at "red", the +3 comes from col1 == 3 on 2020-08-01, which means we jump 3 rows down, and we also take the difference (col2 - col1) - 1 = (5 - 3) - 1 = 1 (at 2020-08-01); the 1 means summing the next 1 value as well, so 0.2 + 0.3 = 0.5. The same logic applies to "blue".

"Green" is the case when Trigger == "F": take (col2 - 1) = 3 - 1 = 2 (at 2020-08-04), where the 2 means summing the next 2 values as well, which gives 0.2 + 0.3 + 0.2 = 0.7.
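To make the rule concrete, here is a plain-Python sketch (the `rule_want` helper is a hypothetical name introduced for illustration, not part of the original post) that reproduces the `want` column: return 0 when col1 == -1; when Trigger == 'T', jump col1 rows and sum col2 - col1 values; when Trigger == 'F', sum col2 values starting from the current row:

```python
# Plain-Python sketch of the summing rule described above.
# `rows` mirrors TEST_data: (Trigger, value, col1, col2).
rows = [
    ('T', 0.0, 3, 5), ('T', 0.0, -1, 4), ('T', 0.0, -1, 3),
    ('F', 0.2, 3, 3), ('T', 0.3, 1, 4), ('F', 0.2, -1, 3),
    ('T', 0.2, -1, 4), ('T', 0.5, -1, 5), ('T', 0.0, -1, 5),
]

def rule_want(i):
    trig, _, col1, col2 = rows[i]
    if col1 == -1:                             # no aggregation at all
        return 0.0
    if trig == 'F':
        start, length = i, col2                # current value plus the next col2 - 1 values
    else:
        start, length = i + col1, col2 - col1  # jump col1 rows, sum col2 - col1 values
    return sum(v for _, v, _, _ in rows[start:start + length])

[round(rule_want(i), 2) for i in range(len(rows))]
# matches the `want` column: [0.5, 0.0, 0.0, 0.7, 0.9, 0.0, 0.0, 0.0, 0.0]
```

This is only a sequential check of the expected values; the Spark solution below does the same thing in a distributed way.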

Edit:

What if I don't need the condition at all? Say we have this df:

from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
from pyspark.sql.functions import to_date

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col2", IntegerType(), True)])
TEST_data = [('2020-08-01',0.0,5),('2020-08-02',0.0,4),('2020-08-03',0.0,3),
             ('2020-08-04',0.2,3),('2020-08-05',0.3,4),('2020-08-06',0.2,3),
             ('2020-08-07',0.2,4),('2020-08-08',0.5,5),('2020-08-09',0.0,5)]
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()


+----------+-----+----+
|      date|value|col2|
+----------+-----+----+
|2020-08-01|  0.0|   5|
|2020-08-02|  0.0|   4|
|2020-08-03|  0.0|   3|
|2020-08-04|  0.2|   3|
|2020-08-05|  0.3|   4|
|2020-08-06|  0.2|   3|
|2020-08-07|  0.2|   4|
|2020-08-08|  0.5|   5|
|2020-08-09|  0.0|   5|
+----------+-----+----+

The same logic applies as when we had the Trigger == "F" condition, i.e. sum the next (col2 - 1) values as well, but in this case there is no condition at all.
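Under that reading, each row simply sums its own value plus the next col2 - 1 values (fewer if the data runs out). A plain-Python sketch of the expected result, assuming that interpretation:

```python
# Plain-Python sketch of the no-condition variant: row i sums
# values[i : i + col2[i]] (truncated at the end of the data).
values = [0.0, 0.0, 0.0, 0.2, 0.3, 0.2, 0.2, 0.5, 0.0]
col2 =   [5,   4,   3,   3,   4,   3,   4,   5,   5]

want = [round(sum(values[i:i + col2[i]]), 2) for i in range(len(values))]
# → [0.5, 0.5, 0.5, 0.7, 1.2, 0.9, 0.7, 0.5, 0.0]
```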

IIUC, we can use the Window function collect_list to gather all related rows, sort the array of structs by date, and then do the aggregation over a slice of this array. The start_idx and span of each slice can be defined as follows:

  1. if col1 = -1, then start_idx = 1 and span = 0, so no aggregation happens
  2. else if Trigger = 'F', then start_idx = 1 and span = col2
  3. else start_idx = col1+1 and span = col2-col1

Notice that the index for the function slice is 1-based.
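As an illustration of the 1-based indexing, Spark SQL's slice(arr, start, length) corresponds to the following Python slicing (`sql_slice` is a hypothetical name used only for this sketch):

```python
# Python analogue of Spark SQL's 1-based slice(arr, start, length);
# like Spark's slice, it just returns fewer elements when length
# reaches past the end of the array.
def sql_slice(arr, start, length):
    return arr[start - 1 : start - 1 + length]

sql_slice([10, 20, 30, 40], 2, 2)  # → [20, 30], like slice(array(10,20,30,40), 2, 2)
```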

Code:

from pyspark.sql.functions import to_date, sort_array, collect_list, struct, expr
from pyspark.sql import Window

w1 = Window.orderBy('date').rowsBetween(0, Window.unboundedFollowing)

# columns used to do calculations, date must be the first field for sorting purpose
cols = ["date", "value", "start_idx", "span"]

df_new = (TEST_df
    .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', 1, col1+1)")) 
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, col2-col1))")) 
    .withColumn('dta', sort_array(collect_list(struct(*cols)).over(w1))) 
    .withColumn("want1", expr("aggregate(slice(dta,start_idx,span), 0D, (acc,x) -> acc+x.value)"))
)

Result:

df_new.show()
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|      date|Trigger|value|col1|col2|want|start_idx|span|                 dta|             want1|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|2020-08-01|      T|  0.0|   3|   5| 0.5|        4|   2|[[2020-08-01, T, ...|0.5000000149011612|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|        1|   0|[[2020-08-02, T, ...|               0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|        1|   0|[[2020-08-03, T, ...|               0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|        1|   3|[[2020-08-04, F, ...|0.7000000178813934|
|2020-08-05|      T|  0.3|   1|   4| 0.9|        2|   3|[[2020-08-05, T, ...|0.9000000059604645|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|        1|   0|[[2020-08-06, F, ...|               0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|        1|   0|[[2020-08-07, T, ...|               0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|        1|   0|[[2020-08-08, T, ...|               0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|        1|   0|[[2020-08-09, T, ...|               0.0|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+

Some explanations:

  1. The slice function requires two parameters besides the target array. In our code, start_idx is the starting index and span is the length of the slice. In the code, I use IF statements to calculate start_idx and span based on the diagram specs in the original post.

  2. The resulting array from collect_list + sort_array over the Window w1 covers rows from the current row to the end of the Window (see the w1 assignment). We then use the slice function inside the aggregate function to retrieve only the necessary array items.

  3. The Spark SQL built-in function aggregate takes the following form:

     aggregate(expr, start, merge, finish) 
    

    The 4th argument finish can be skipped. In our case, it can be reformatted as below (you can copy the following to replace the expr inside .withColumn('want1', expr(""" .... """))):

     aggregate(
       /* targeting array, use slice function to take only part of the array `dta` */
       slice(dta,start_idx,span), 
       /* start, zero_value used for reduce */
       0D, 
       /* merge, similar to reduce function */
       (acc,x) -> acc+x.value,
       /* finish, skipped in the post, but you can do some post-processing here, for example, round-up the result from merge */
       acc -> round(acc, 2)
     )
    

    The aggregate function works similarly to the reduce function in Python; the second argument is the zero value (0D is a shortcut for double(0), which sets the data type of the aggregation variable acc).

  4. As mentioned in the comments, if col2 < col1 where Trigger = 'T' and col1 != -1, the current code will yield a negative span. In such a case, we should use a full-size Window spec:

     w1 = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)        
    

    and use array_position to find the position of the current row, then calculate start_idx based on this position.
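The aggregate/reduce analogy from point 3 can be sketched in plain Python (the row dicts below are illustrative stand-ins for the sliced array of structs):

```python
from functools import reduce

# Plain-Python analogue of:
#   aggregate(slice(dta, start_idx, span), 0D, (acc, x) -> acc + x.value)
sliced = [{'value': 0.2}, {'value': 0.3}]                     # stands in for slice(dta, ...)
total = reduce(lambda acc, x: acc + x['value'], sliced, 0.0)  # 0.0 plays the role of 0D
round(total, 2)  # → 0.5, the "red" sum from the question
```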