在 PySpark 数据框 DF 中的用户级别迭代(循环)
Iterate (loop) at user level in PySpark dataframe DF
假设我在 PySpark 中有以下 DF,其中 UB 和 LB 分别代表上限和下限。
+---------+-----+--------------+------+------+
| user_id | row | currentValue | UB | LB |
+---------+-----+--------------+------+------+
| usr001 | 1 | 12 | 7.2 | 16.8 |
| usr001 | 2 | 20 | 12 | 28 |
| usr001 | 3 | 17 | 10.2 | 23.8 |
| usr001 | 4 | 21 | 12.6 | 29.4 |
| usr001 | 5 | 9 | 5.4 | 12.6 |
| usr001 | 6 | 23 | 13.8 | 32.2 |
| usr002 | 1 | 11 | 6.6 | 15.4 |
| usr002 | 2 | 10 | 6 | 14 |
| usr002 | 3 | 15 | 9 | 21 |
| usr002 | 4 | 3 | 1.8 | 4.2 |
| usr002 | 5 | 4 | 2.4 | 5.6 |
+---------+-----+--------------+------+------+
对于 DF 中的每个用户,我想应用一些 logic/rules 以便 currentValue 可以更新为 updatedValue。 Logic/rules如下:
user_id: usr001
- 对于第 1 行:currentValue = updatedValue(对于所有用户)
- 对于第 2 行:如果 currentValue 在第 1 行的 LB 和 UB 之间(如果 20 在 7.2 和 16.8 之间),则第 2 行的 updatedValue 等于第 1 行的 currentValue(第 2 行的 updatedValue = 12)。否则,updatedValue = currentValue (updatedValue = 20)
- 因为在第 2 行中,updatedValue = currentValue,第 3 行与第 2 行进行比较。
对于第 3 行:如果 currentValue 在第 2 行的 LB 和 UB 范围内(如果 17 在 12 和 28 之间),则第 3 行的 updatedValue 等于第 2 行的 currentValue(第 3 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 17)
- 因为在第3行中,第3行中的updatedValue =第2行中的currentValue,第4行与第2行进行比较。
对于第 4 行:如果 currentValue 在第 2 行的 LB 和 UB 之间(如果 21 在 12 和 28 之间),则第 4 行的 updatedValue 等于第 2 行的 currentValue(第 4 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 21)
- 因为在第4行中,第4行中的updatedValue =第2行中的currentValue,第5行与第2行进行比较。
对于第 5 行:如果 currentValue 在第 2 行的 LB 和 UB 之内(如果 9 在 12 和 28 之间),则第 5 行的 updatedValue 等于第 2 行的 currentValue(第 5 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 9)
- 因为在第 5 行中,updatedValue = currentValue,第 6 行与第 5 行进行比较。
对于第 6 行:如果 currentValue 在第 5 行的 LB 和 UB 范围内(如果 23 在 5.4 和 12.6 之间),则第 6 行的 updatedValue 等于第 2 行的 currentValue(第 5 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 9)
确切的规则将适用于 usr002。预期输出如下:
+---------+-----+--------------+------+------+--------------+
| user_id | row | currentValue | UB | LB | updatedValue |
+---------+-----+--------------+------+------+--------------+
| usr001 | 1 | 12 | 7.2 | 16.8 | 12 |
| usr001 | 2 | 20 | 12 | 28 | 20 |
| usr001 | 3 | 17 | 10.2 | 23.8 | 20 |
| usr001 | 4 | 21 | 12.6 | 29.4 | 20 |
| usr001 | 5 | 9 | 5.4 | 12.6 | 9 |
| usr001 | 6 | 23 | 13.8 | 32.2 | 23 |
| usr002 | 1 | 11 | 6.6 | 15.4 | 11 |
| usr002 | 2 | 10 | 6 | 14 | 11 |
| usr002 | 3 | 15 | 9 | 21 | 11 |
| usr002 | 4 | 3 | 1.8 | 4.2 | 3 |
| usr002 | 5 | 4 | 2.4 | 5.6 | 3 |
+---------+-----+--------------+------+------+--------------+
有什么方法可以在 Spark 中实现吗?感谢您的帮助!
Spark: 2.4.4
您可以使用 window 功能。但它不是那么简单。这是对代码和逻辑的逐步解释。
(在下面的代码和解释中uv和updatedValue是不一样的)
1.Read df
df=spark.read.csv(path, header=True, inferSchema=True)
2.Specify window
w=Window.partitionBy("user_id").orderBy("row")
3.Create 将当前值与前一行的 UB 和 LB 进行比较的列,如果它在范围内,则 return 前一行 currentValue 否则 return 同一行 currentValue,让将此列称为 "uv"
df2=df.withColumn("uv",when(col("row")==1,col("currentValue"))
.when(col("currentValue").between(lag("UB",1).over(w),
lag("LB",1).over(w)),lag("currentValue",1).over(w))
.otherwise(col("currentValue"))).orderBy("user_id")
df2:
+-------+---+------------+----+----+---+
|user_id|row|currentValue| UB| LB| uv|
+-------+---+------------+----+----+---+
| usr001| 1| 12| 7.2|16.8| 12|
| usr001| 2| 20|12.0|28.0| 20|
| usr001| 3| 17|10.2|23.8| 20|
| usr001| 4| 21|12.6|29.4| 17|
| usr001| 5| 9| 5.4|12.6| 9|
| usr001| 6| 23|13.8|32.2| 23|
| usr002| 1| 11| 6.6|15.4| 11|
| usr002| 2| 10| 6.0|14.0| 11|
| usr002| 3| 15| 9.0|21.0| 15|
| usr002| 4| 3| 1.8| 4.2| 3|
| usr002| 5| 4| 2.4| 5.6| 3|
+-------+---+------------+----+----+---+
4.This 是主要逻辑,根据第 5 行 (usr001) 的逻辑,首先我们必须检查第 4 行更新值是否填充了第 4 行当前值,如果已填充则将第 5 行值与行进行比较4 个边界,否则我们必须转到填充第 4 行 updatedValue 的行并与这些边界进行比较,以在上述步骤中实现此标记所有值,其中 currentValue==uv.
df3=df2.withColumn("comp_row", when(col("currentValue")==col("uv"), col("row")))
df3:
+-------+---+------------+----+----+---+--------+
|user_id|row|currentValue| UB| LB| uv|comp_row|
+-------+---+------------+----+----+---+--------+
| usr001| 1| 12| 7.2|16.8| 12| 1|
| usr001| 2| 20|12.0|28.0| 20| 2|
| usr001| 3| 17|10.2|23.8| 20| null|
| usr001| 4| 21|12.6|29.4| 17| null|
| usr001| 5| 9| 5.4|12.6| 9| 5|
| usr001| 6| 23|13.8|32.2| 23| 6|
| usr002| 1| 11| 6.6|15.4| 11| 1|
| usr002| 2| 10| 6.0|14.0| 11| null|
| usr002| 3| 15| 9.0|21.0| 15| 3|
| usr002| 4| 3| 1.8| 4.2| 3| 4|
| usr002| 5| 4| 2.4| 5.6| 3| null|
+-------+---+------------+----+----+---+--------+
5.Now 如果我们回填每行的空值,我们将得到每行应该与之比较的行号。
df4 = df3.withColumn("comp_row",last("comp_row",True).over(w))
df4:
+-------+---+------------+----+----+---+--------+
|user_id|row|currentValue| UB| LB| uv|comp_row|
+-------+---+------------+----+----+---+--------+
| usr001| 1| 12| 7.2|16.8| 12| 1|
| usr001| 2| 20|12.0|28.0| 20| 2|
| usr001| 3| 17|10.2|23.8| 20| 2|
| usr001| 4| 21|12.6|29.4| 17| 2|
| usr001| 5| 9| 5.4|12.6| 9| 5|
| usr001| 6| 23|13.8|32.2| 23| 6|
| usr002| 1| 11| 6.6|15.4| 11| 1|
| usr002| 2| 10| 6.0|14.0| 11| 1|
| usr002| 3| 15| 9.0|21.0| 15| 3|
| usr002| 4| 3| 1.8| 4.2| 3| 4|
| usr002| 5| 4| 2.4| 5.6| 3| 4|
+-------+---+------------+----+----+---+--------+
注意:comp_row 的值表示下一行应该与哪一行进行比较,例如:第 4 行(usr001)comp_row 包含 2 表示第 5 行与第 2 行进行比较。
6.Now 我们知道哪一行与哪一行比较,我们需要做的只是获取这些行的边界。为此,我们需要将行与 comp_row 连接起来,这样我们就可以在第 4 行中获得第 2 行的边界。
df5 = df4.select("user_id",col("row").alias("comp_row"),
col("UB").alias("new_UB"),col("LB").alias("new_LB")
,col("currentValue").alias("new_currentValue"))
# Note: Here row is selected as comp_row.
df6=df5.join(df4,["user_id","comp_row"],"inner").orderBy("user_id","row")
df6.select("user_id",
"UB","LB"
,"new_UB","new_LB"
,"currentValue","new_currentValue"
,"row","comp_row").show()
+-------+----+----+------+------+------------+----------------+---+--------+
|user_id| UB| LB|new_UB|new_LB|currentValue|new_currentValue|row|comp_row|
+-------+----+----+------+------+------------+----------------+---+--------+
| usr001| 7.2|16.8| 7.2| 16.8| 12| 12| 1| 1|
| usr001|12.0|28.0| 12.0| 28.0| 20| 20| 2| 2|
| usr001|10.2|23.8| 12.0| 28.0| 17| 20| 3| 2|
| usr001|12.6|29.4| 12.0| 28.0| 21| 20| 4| 2|
| usr001| 5.4|12.6| 5.4| 12.6| 9| 9| 5| 5|
| usr001|13.8|32.2| 13.8| 32.2| 23| 23| 6| 6|
| usr002| 6.6|15.4| 6.6| 15.4| 11| 11| 1| 1|
| usr002| 6.0|14.0| 6.6| 15.4| 10| 11| 2| 1|
| usr002| 9.0|21.0| 9.0| 21.0| 15| 15| 3| 3|
| usr002| 1.8| 4.2| 1.8| 4.2| 3| 3| 4| 4|
| usr002| 2.4| 5.6| 1.8| 4.2| 4| 3| 5| 4|
+-------+----+----+------+------+------------+----------------+---+--------+
7.The final Step and Boom!!,将 currentValues 与前一行中的新边界进行比较,如果它在边界内,则 updatedValue=new_currentValue 前一行的 else updatedValue=currentValue 同一行。
df7=df6.withColumn("updatedValue",when(col("row")==1,col("currentValue"))\
.when(col("currentValue").between(lag("new_UB",1).over(w),
lag("new_LB",1).over(w)),lag("new_currentValue",1).over(w))
.otherwise(col("currentValue"))).orderBy("user_id")\
.select("user_id","currentValue","UB","LB","updatedValue")
df7:
+-------+------------+----+----+------------+
|user_id|currentValue| UB| LB|updatedValue|
+-------+------------+----+----+------------+
| usr001| 12| 7.2|16.8| 12|
| usr001| 20|12.0|28.0| 20|
| usr001| 17|10.2|23.8| 20|
| usr001| 21|12.6|29.4| 20|
| usr001| 9| 5.4|12.6| 9|
| usr001| 23|13.8|32.2| 23|
| usr002| 11| 6.6|15.4| 11|
| usr002| 10| 6.0|14.0| 11|
| usr002| 15| 9.0|21.0| 11|
| usr002| 3| 1.8| 4.2| 3|
| usr002| 4| 2.4| 5.6| 3|
+-------+------------+----+----+------------+
假设我在 PySpark 中有以下 DF,其中 UB 和 LB 分别代表上限和下限。
+---------+-----+--------------+------+------+ | user_id | row | currentValue | UB | LB | +---------+-----+--------------+------+------+ | usr001 | 1 | 12 | 7.2 | 16.8 | | usr001 | 2 | 20 | 12 | 28 | | usr001 | 3 | 17 | 10.2 | 23.8 | | usr001 | 4 | 21 | 12.6 | 29.4 | | usr001 | 5 | 9 | 5.4 | 12.6 | | usr001 | 6 | 23 | 13.8 | 32.2 | | usr002 | 1 | 11 | 6.6 | 15.4 | | usr002 | 2 | 10 | 6 | 14 | | usr002 | 3 | 15 | 9 | 21 | | usr002 | 4 | 3 | 1.8 | 4.2 | | usr002 | 5 | 4 | 2.4 | 5.6 | +---------+-----+--------------+------+------+
对于 DF 中的每个用户,我想应用一些 logic/rules 以便 currentValue 可以更新为 updatedValue。 Logic/rules如下:
user_id: usr001
- 对于第 1 行:currentValue = updatedValue(对于所有用户)
- 对于第 2 行:如果 currentValue 在第 1 行的 LB 和 UB 之间(如果 20 在 7.2 和 16.8 之间),则第 2 行的 updatedValue 等于第 1 行的 currentValue(第 2 行的 updatedValue = 12)。否则,updatedValue = currentValue (updatedValue = 20)
- 因为在第 2 行中,updatedValue = currentValue,第 3 行与第 2 行进行比较。 对于第 3 行:如果 currentValue 在第 2 行的 LB 和 UB 范围内(如果 17 在 12 和 28 之间),则第 3 行的 updatedValue 等于第 2 行的 currentValue(第 3 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 17)
- 因为在第3行中,第3行中的updatedValue =第2行中的currentValue,第4行与第2行进行比较。 对于第 4 行:如果 currentValue 在第 2 行的 LB 和 UB 之间(如果 21 在 12 和 28 之间),则第 4 行的 updatedValue 等于第 2 行的 currentValue(第 4 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 21)
- 因为在第4行中,第4行中的updatedValue =第2行中的currentValue,第5行与第2行进行比较。 对于第 5 行:如果 currentValue 在第 2 行的 LB 和 UB 之内(如果 9 在 12 和 28 之间),则第 5 行的 updatedValue 等于第 2 行的 currentValue(第 5 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 9)
- 因为在第 5 行中,updatedValue = currentValue,第 6 行与第 5 行进行比较。 对于第 6 行:如果 currentValue 在第 5 行的 LB 和 UB 范围内(如果 23 在 5.4 和 12.6 之间),则第 6 行的 updatedValue 等于第 2 行的 currentValue(第 5 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 9)
确切的规则将适用于 usr002。预期输出如下:
+---------+-----+--------------+------+------+--------------+ | user_id | row | currentValue | UB | LB | updatedValue | +---------+-----+--------------+------+------+--------------+ | usr001 | 1 | 12 | 7.2 | 16.8 | 12 | | usr001 | 2 | 20 | 12 | 28 | 20 | | usr001 | 3 | 17 | 10.2 | 23.8 | 20 | | usr001 | 4 | 21 | 12.6 | 29.4 | 20 | | usr001 | 5 | 9 | 5.4 | 12.6 | 9 | | usr001 | 6 | 23 | 13.8 | 32.2 | 23 | | usr002 | 1 | 11 | 6.6 | 15.4 | 11 | | usr002 | 2 | 10 | 6 | 14 | 11 | | usr002 | 3 | 15 | 9 | 21 | 11 | | usr002 | 4 | 3 | 1.8 | 4.2 | 3 | | usr002 | 5 | 4 | 2.4 | 5.6 | 3 | +---------+-----+--------------+------+------+--------------+
有什么方法可以在 Spark 中实现吗?感谢您的帮助!
Spark: 2.4.4
您可以使用 window 功能。但它不是那么简单。这是对代码和逻辑的逐步解释。
(在下面的代码和解释中uv和updatedValue是不一样的)
1.Read df
df=spark.read.csv(path, header=True, inferSchema=True)
2.Specify window
w=Window.partitionBy("user_id").orderBy("row")
3.Create 将当前值与前一行的 UB 和 LB 进行比较的列,如果它在范围内,则 return 前一行 currentValue 否则 return 同一行 currentValue,让将此列称为 "uv"
df2=df.withColumn("uv",when(col("row")==1,col("currentValue"))
.when(col("currentValue").between(lag("UB",1).over(w),
lag("LB",1).over(w)),lag("currentValue",1).over(w))
.otherwise(col("currentValue"))).orderBy("user_id")
df2:
+-------+---+------------+----+----+---+
|user_id|row|currentValue| UB| LB| uv|
+-------+---+------------+----+----+---+
| usr001| 1| 12| 7.2|16.8| 12|
| usr001| 2| 20|12.0|28.0| 20|
| usr001| 3| 17|10.2|23.8| 20|
| usr001| 4| 21|12.6|29.4| 17|
| usr001| 5| 9| 5.4|12.6| 9|
| usr001| 6| 23|13.8|32.2| 23|
| usr002| 1| 11| 6.6|15.4| 11|
| usr002| 2| 10| 6.0|14.0| 11|
| usr002| 3| 15| 9.0|21.0| 15|
| usr002| 4| 3| 1.8| 4.2| 3|
| usr002| 5| 4| 2.4| 5.6| 3|
+-------+---+------------+----+----+---+
4.This 是主要逻辑,根据第 5 行 (usr001) 的逻辑,首先我们必须检查第 4 行更新值是否填充了第 4 行当前值,如果已填充则将第 5 行值与行进行比较4 个边界,否则我们必须转到填充第 4 行 updatedValue 的行并与这些边界进行比较,以在上述步骤中实现此标记所有值,其中 currentValue==uv.
df3=df2.withColumn("comp_row", when(col("currentValue")==col("uv"), col("row")))
df3:
+-------+---+------------+----+----+---+--------+
|user_id|row|currentValue| UB| LB| uv|comp_row|
+-------+---+------------+----+----+---+--------+
| usr001| 1| 12| 7.2|16.8| 12| 1|
| usr001| 2| 20|12.0|28.0| 20| 2|
| usr001| 3| 17|10.2|23.8| 20| null|
| usr001| 4| 21|12.6|29.4| 17| null|
| usr001| 5| 9| 5.4|12.6| 9| 5|
| usr001| 6| 23|13.8|32.2| 23| 6|
| usr002| 1| 11| 6.6|15.4| 11| 1|
| usr002| 2| 10| 6.0|14.0| 11| null|
| usr002| 3| 15| 9.0|21.0| 15| 3|
| usr002| 4| 3| 1.8| 4.2| 3| 4|
| usr002| 5| 4| 2.4| 5.6| 3| null|
+-------+---+------------+----+----+---+--------+
5.Now 如果我们回填每行的空值,我们将得到每行应该与之比较的行号。
df4 = df3.withColumn("comp_row",last("comp_row",True).over(w))
df4:
+-------+---+------------+----+----+---+--------+
|user_id|row|currentValue| UB| LB| uv|comp_row|
+-------+---+------------+----+----+---+--------+
| usr001| 1| 12| 7.2|16.8| 12| 1|
| usr001| 2| 20|12.0|28.0| 20| 2|
| usr001| 3| 17|10.2|23.8| 20| 2|
| usr001| 4| 21|12.6|29.4| 17| 2|
| usr001| 5| 9| 5.4|12.6| 9| 5|
| usr001| 6| 23|13.8|32.2| 23| 6|
| usr002| 1| 11| 6.6|15.4| 11| 1|
| usr002| 2| 10| 6.0|14.0| 11| 1|
| usr002| 3| 15| 9.0|21.0| 15| 3|
| usr002| 4| 3| 1.8| 4.2| 3| 4|
| usr002| 5| 4| 2.4| 5.6| 3| 4|
+-------+---+------------+----+----+---+--------+
注意:comp_row 的值表示下一行应该与哪一行进行比较,例如:第 4 行(usr001)comp_row 包含 2 表示第 5 行与第 2 行进行比较。
6.Now 我们知道哪一行与哪一行比较,我们需要做的只是获取这些行的边界。为此,我们需要将行与 comp_row 连接起来,这样我们就可以在第 4 行中获得第 2 行的边界。
df5 = df4.select("user_id",col("row").alias("comp_row"),
col("UB").alias("new_UB"),col("LB").alias("new_LB")
,col("currentValue").alias("new_currentValue"))
# Note: Here row is selected as comp_row.
df6=df5.join(df4,["user_id","comp_row"],"inner").orderBy("user_id","row")
df6.select("user_id",
"UB","LB"
,"new_UB","new_LB"
,"currentValue","new_currentValue"
,"row","comp_row").show()
+-------+----+----+------+------+------------+----------------+---+--------+
|user_id| UB| LB|new_UB|new_LB|currentValue|new_currentValue|row|comp_row|
+-------+----+----+------+------+------------+----------------+---+--------+
| usr001| 7.2|16.8| 7.2| 16.8| 12| 12| 1| 1|
| usr001|12.0|28.0| 12.0| 28.0| 20| 20| 2| 2|
| usr001|10.2|23.8| 12.0| 28.0| 17| 20| 3| 2|
| usr001|12.6|29.4| 12.0| 28.0| 21| 20| 4| 2|
| usr001| 5.4|12.6| 5.4| 12.6| 9| 9| 5| 5|
| usr001|13.8|32.2| 13.8| 32.2| 23| 23| 6| 6|
| usr002| 6.6|15.4| 6.6| 15.4| 11| 11| 1| 1|
| usr002| 6.0|14.0| 6.6| 15.4| 10| 11| 2| 1|
| usr002| 9.0|21.0| 9.0| 21.0| 15| 15| 3| 3|
| usr002| 1.8| 4.2| 1.8| 4.2| 3| 3| 4| 4|
| usr002| 2.4| 5.6| 1.8| 4.2| 4| 3| 5| 4|
+-------+----+----+------+------+------------+----------------+---+--------+
7.The final Step and Boom!!,将 currentValues 与前一行中的新边界进行比较,如果它在边界内,则 updatedValue=new_currentValue 前一行的 else updatedValue=currentValue 同一行。
df7=df6.withColumn("updatedValue",when(col("row")==1,col("currentValue"))\
.when(col("currentValue").between(lag("new_UB",1).over(w),
lag("new_LB",1).over(w)),lag("new_currentValue",1).over(w))
.otherwise(col("currentValue"))).orderBy("user_id")\
.select("user_id","currentValue","UB","LB","updatedValue")
df7:
+-------+------------+----+----+------------+
|user_id|currentValue| UB| LB|updatedValue|
+-------+------------+----+----+------------+
| usr001| 12| 7.2|16.8| 12|
| usr001| 20|12.0|28.0| 20|
| usr001| 17|10.2|23.8| 20|
| usr001| 21|12.6|29.4| 20|
| usr001| 9| 5.4|12.6| 9|
| usr001| 23|13.8|32.2| 23|
| usr002| 11| 6.6|15.4| 11|
| usr002| 10| 6.0|14.0| 11|
| usr002| 15| 9.0|21.0| 11|
| usr002| 3| 1.8| 4.2| 3|
| usr002| 4| 2.4| 5.6| 3|
+-------+------------+----+----+------------+