使用 pyspark 进行递归计算
Recursive computation with pypsark
我希望能够在 window 内计算两列之间的总和,并在该总和变为奇数时修改列的值。因此,对该列的修改实际上会修改总和等。
但是,我不知道如何以有效的方式“逐行”扫描我的数据。
你有小费吗?
为了清楚起见,我附上了一些示例行和我想要实现的目标:
我的 window 将基于 ID
列
my_data = spark.createDataFrame([
(1, 1, 0),
(1, 0, 0),
(1, 0, 1),
(1, 0, 0),
(1, 0, 1),
(1, 0, 0),
(1, 1, 0),
(1, 0, 0),
(1, 0, 1),
(1, 0, 0),
(1, 1, 0),
(1, 0, 0),
(1, 1, 0),
(1, 0, 1),
],
['ID','flag_1','flag_2'])
因此,我的问题是导出总和,同时在总和变为奇数时修改 flag_2
。 sum
是预期的结果,flag_2_results
是 flag_2
的“清理”版本,
my_data = spark.createDataFrame([
(1, 1, 0, 0, 1),
(1, 0, 0, 0, 1),
(1, 0, 1, 1, 2),
(1, 0, 0, 0, 2),
(1, 0, 1, 0, 2),
(1, 0, 0, 0, 2),
(1, 1, 0, 0, 3),
(1, 0, 0, 0, 3),
(1, 0, 1, 1, 4),
(1, 0, 0, 0, 4),
(1, 1, 0, 0, 5),
(1, 0, 0, 0, 5),
(1, 1, 0, 0, 6),
(1, 0, 1, 0, 6),],
['ID','flag_1','flag_2', 'flag_2_results', 'sum'])
- 原始 n°3:我们保留
flag_2 = 1
因为 sum
是奇数。
- 原始 n°5:我们不保留
flag_2 = 1
,因为 sum
是偶数,因此 sum
直到 flag_1 = 1
.[=37 才改变=]
- Last raw : 我们不保留
flag_2 = 1
(即使它是 flag_1 = 1
之后的第一个)因为它会导致奇数累积和
感谢您的帮助,
根据您最后的评论,您没有那么多行要处理。然后,我建议您仅在 "flag1+flag2>0" :
行上使用 UDF
from pyspark.sql import functions as F, Window as W, types as T
df = my_data.groupBy("ID").agg(
F.collect_list(F.struct(F.col("posTime"), F.col("flag_1"), F.col("flag_2"))).alias(
"data"
)
)
schm = T.ArrayType(
T.StructType(
[
T.StructField("posTime", T.IntegerType()),
T.StructField("flag_1", T.IntegerType()),
T.StructField("flag_2", T.IntegerType()),
T.StructField("flag_2_result", T.IntegerType()),
T.StructField("sum", T.IntegerType()),
]
)
)
@F.udf(schm)
def process(data):
accumulator = 0
out = []
data.sort(key=lambda x: x["posTime"])
for l in data:
flag_2_result = 0
accumulator += l["flag_1"]
if l["flag_2"] and accumulator % 2 == 1:
accumulator += l["flag_2"]
flag_2_result = 1
out.append((l["posTime"], l["flag_1"], l["flag_2"], flag_2_result, accumulator))
return out
df.select("ID", F.explode(process(F.col("data"))).alias("data")).select(
"ID", "data.*"
).show()
结果:
+---+-------+------+------+-------------+---+
| ID|posTime|flag_1|flag_2|flag_2_result|sum|
+---+-------+------+------+-------------+---+
| 1| 1| 1| 0| 0| 1|
| 1| 2| 0| 0| 0| 1|
| 1| 3| 0| 1| 1| 2|
| 1| 4| 0| 0| 0| 2|
| 1| 5| 0| 1| 0| 2|
| 1| 6| 0| 0| 0| 2|
| 1| 7| 1| 0| 0| 3|
| 1| 8| 0| 0| 0| 3|
| 1| 9| 0| 1| 1| 4|
| 1| 10| 0| 0| 0| 4|
| 1| 11| 1| 0| 0| 5|
| 1| 12| 0| 0| 0| 5|
| 1| 13| 1| 0| 0| 6|
| 1| 14| 0| 1| 0| 6|
+---+-------+------+------+-------------+---+
我希望能够在 window 内计算两列之间的总和,并在该总和变为奇数时修改列的值。因此,对该列的修改实际上会修改总和等。 但是,我不知道如何以有效的方式“逐行”扫描我的数据。 你有小费吗?
为了清楚起见,我附上了一些示例行和我想要实现的目标:
我的 window 将基于 ID
列
my_data = spark.createDataFrame([
(1, 1, 0),
(1, 0, 0),
(1, 0, 1),
(1, 0, 0),
(1, 0, 1),
(1, 0, 0),
(1, 1, 0),
(1, 0, 0),
(1, 0, 1),
(1, 0, 0),
(1, 1, 0),
(1, 0, 0),
(1, 1, 0),
(1, 0, 1),
],
['ID','flag_1','flag_2'])
因此,我的问题是导出总和,同时在总和变为奇数时修改 flag_2
。 sum
是预期的结果,flag_2_results
是 flag_2
的“清理”版本,
my_data = spark.createDataFrame([
(1, 1, 0, 0, 1),
(1, 0, 0, 0, 1),
(1, 0, 1, 1, 2),
(1, 0, 0, 0, 2),
(1, 0, 1, 0, 2),
(1, 0, 0, 0, 2),
(1, 1, 0, 0, 3),
(1, 0, 0, 0, 3),
(1, 0, 1, 1, 4),
(1, 0, 0, 0, 4),
(1, 1, 0, 0, 5),
(1, 0, 0, 0, 5),
(1, 1, 0, 0, 6),
(1, 0, 1, 0, 6),],
['ID','flag_1','flag_2', 'flag_2_results', 'sum'])
- 原始 n°3:我们保留
flag_2 = 1
因为sum
是奇数。 - 原始 n°5:我们不保留
flag_2 = 1
,因为sum
是偶数,因此sum
直到flag_1 = 1
.[=37 才改变=] - Last raw : 我们不保留
flag_2 = 1
(即使它是flag_1 = 1
之后的第一个)因为它会导致奇数累积和
感谢您的帮助,
根据您最后的评论,您没有那么多行要处理。然后,我建议您仅在 "flag1+flag2>0" :
行上使用 UDFfrom pyspark.sql import functions as F, Window as W, types as T
df = my_data.groupBy("ID").agg(
F.collect_list(F.struct(F.col("posTime"), F.col("flag_1"), F.col("flag_2"))).alias(
"data"
)
)
schm = T.ArrayType(
T.StructType(
[
T.StructField("posTime", T.IntegerType()),
T.StructField("flag_1", T.IntegerType()),
T.StructField("flag_2", T.IntegerType()),
T.StructField("flag_2_result", T.IntegerType()),
T.StructField("sum", T.IntegerType()),
]
)
)
@F.udf(schm)
def process(data):
accumulator = 0
out = []
data.sort(key=lambda x: x["posTime"])
for l in data:
flag_2_result = 0
accumulator += l["flag_1"]
if l["flag_2"] and accumulator % 2 == 1:
accumulator += l["flag_2"]
flag_2_result = 1
out.append((l["posTime"], l["flag_1"], l["flag_2"], flag_2_result, accumulator))
return out
df.select("ID", F.explode(process(F.col("data"))).alias("data")).select(
"ID", "data.*"
).show()
结果:
+---+-------+------+------+-------------+---+
| ID|posTime|flag_1|flag_2|flag_2_result|sum|
+---+-------+------+------+-------------+---+
| 1| 1| 1| 0| 0| 1|
| 1| 2| 0| 0| 0| 1|
| 1| 3| 0| 1| 1| 2|
| 1| 4| 0| 0| 0| 2|
| 1| 5| 0| 1| 0| 2|
| 1| 6| 0| 0| 0| 2|
| 1| 7| 1| 0| 0| 3|
| 1| 8| 0| 0| 0| 3|
| 1| 9| 0| 1| 1| 4|
| 1| 10| 0| 0| 0| 4|
| 1| 11| 1| 0| 0| 5|
| 1| 12| 0| 0| 0| 5|
| 1| 13| 1| 0| 0| 6|
| 1| 14| 0| 1| 0| 6|
+---+-------+------+------+-------------+---+