使用 pyspark 进行递归计算

Recursive computation with pypsark

我希望能够在 window 内计算两列之间的总和,并在该总和变为奇数时修改列的值。因此,对该列的修改实际上会修改总和等。 但是,我不知道如何以有效的方式“逐行”扫描我的数据。 你有小费吗?

为了清楚起见,我附上了一些示例行和我想要实现的目标: 我的 window 将基于 ID

my_data = spark.createDataFrame([
    (1, 1, 0), 
    (1, 0, 0), 
    (1, 0, 1), 
    (1, 0, 0), 
    (1, 0, 1), 
    (1, 0, 0), 
    (1, 1, 0),
    (1, 0, 0),
    (1, 0, 1),
    (1, 0, 0), 
    (1, 1, 0),
    (1, 0, 0),
    (1, 1, 0),
    (1, 0, 1),
],
    ['ID','flag_1','flag_2'])

因此,我的问题是导出总和,同时在总和变为奇数时修改 flag_2sum 是预期的结果,flag_2_resultsflag_2 的“清理”版本,

my_data = spark.createDataFrame([
    (1, 1, 0, 0, 1), 
    (1, 0, 0, 0, 1), 
    (1, 0, 1, 1, 2), 
    (1, 0, 0, 0, 2), 
    (1, 0, 1, 0, 2), 
    (1, 0, 0, 0, 2), 
    (1, 1, 0, 0, 3),
    (1, 0, 0, 0, 3),
    (1, 0, 1, 1, 4),
    (1, 0, 0, 0, 4),
    (1, 1, 0, 0, 5),
    (1, 0, 0, 0, 5),
    (1, 1, 0, 0, 6),
    (1, 0, 1, 0, 6),],
    ['ID','flag_1','flag_2', 'flag_2_results', 'sum'])

感谢您的帮助,

根据您最后的评论,您没有那么多行要处理。然后,我建议您仅在 "flag1+flag2>0" :

行上使用 UDF
from pyspark.sql import functions as F, Window as W, types as T


df = my_data.groupBy("ID").agg(
    F.collect_list(F.struct(F.col("posTime"), F.col("flag_1"), F.col("flag_2"))).alias(
        "data"
    )
)

schm = T.ArrayType(
    T.StructType(
        [
            T.StructField("posTime", T.IntegerType()),
            T.StructField("flag_1", T.IntegerType()),
            T.StructField("flag_2", T.IntegerType()),
            T.StructField("flag_2_result", T.IntegerType()),
            T.StructField("sum", T.IntegerType()),
        ]
    )
)


@F.udf(schm)
def process(data):
    accumulator = 0
    out = []
    data.sort(key=lambda x: x["posTime"])
    for l in data:
        flag_2_result = 0
        accumulator += l["flag_1"]
        if l["flag_2"] and accumulator % 2 == 1:
            accumulator += l["flag_2"]
            flag_2_result = 1
        out.append((l["posTime"], l["flag_1"], l["flag_2"], flag_2_result, accumulator))
    return out


df.select("ID", F.explode(process(F.col("data"))).alias("data")).select(
    "ID", "data.*"
).show()

结果:

+---+-------+------+------+-------------+---+                                   
| ID|posTime|flag_1|flag_2|flag_2_result|sum|
+---+-------+------+------+-------------+---+
|  1|      1|     1|     0|            0|  1|
|  1|      2|     0|     0|            0|  1|
|  1|      3|     0|     1|            1|  2|
|  1|      4|     0|     0|            0|  2|
|  1|      5|     0|     1|            0|  2|
|  1|      6|     0|     0|            0|  2|
|  1|      7|     1|     0|            0|  3|
|  1|      8|     0|     0|            0|  3|
|  1|      9|     0|     1|            1|  4|
|  1|     10|     0|     0|            0|  4|
|  1|     11|     1|     0|            0|  5|
|  1|     12|     0|     0|            0|  5|
|  1|     13|     1|     0|            0|  6|
|  1|     14|     0|     1|            0|  6|
+---+-------+------+------+-------------+---+