在 PySpark 数据框 DF 中的用户级别迭代(循环)

Iterate (loop) at user level in PySpark dataframe DF

假设我在 PySpark 中有以下 DF,其中 UB 和 LB 分别代表上限和下限。

+---------+-----+--------------+------+------+
| user_id | row | currentValue |  UB  |  LB  |
+---------+-----+--------------+------+------+
| usr001  |   1 |           12 | 7.2  | 16.8 |
| usr001  |   2 |           20 | 12   | 28   |
| usr001  |   3 |           17 | 10.2 | 23.8 |
| usr001  |   4 |           21 | 12.6 | 29.4 |
| usr001  |   5 |            9 | 5.4  | 12.6 |
| usr001  |   6 |           23 | 13.8 | 32.2 |
| usr002  |   1 |           11 | 6.6  | 15.4 |
| usr002  |   2 |           10 | 6    | 14   |
| usr002  |   3 |           15 | 9    | 21   |
| usr002  |   4 |            3 | 1.8  | 4.2  |
| usr002  |   5 |            4 | 2.4  | 5.6  |
+---------+-----+--------------+------+------+

对于 DF 中的每个用户,我想应用一些 logic/rules 以便 currentValue 可以更新为 updatedValue。 Logic/rules如下:

user_id: usr001

  1. 对于第 1 行:currentValue = updatedValue(对于所有用户)
  2. 对于第 2 行:如果 currentValue 在第 1 行的 LB 和 UB 之间(如果 20 在 7.2 和 16.8 之间),则第 2 行的 updatedValue 等于第 1 行的 currentValue(第 2 行的 updatedValue = 12)。否则,updatedValue = currentValue (updatedValue = 20)
  3. 因为在第 2 行中,updatedValue = currentValue,第 3 行与第 2 行进行比较。 对于第 3 行:如果 currentValue 在第 2 行的 LB 和 UB 范围内(如果 17 在 12 和 28 之间),则第 3 行的 updatedValue 等于第 2 行的 currentValue(第 3 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 17)
  4. 因为在第3行中,第3行中的updatedValue =第2行中的currentValue,第4行与第2行进行比较。 对于第 4 行:如果 currentValue 在第 2 行的 LB 和 UB 之间(如果 21 在 12 和 28 之间),则第 4 行的 updatedValue 等于第 2 行的 currentValue(第 4 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 21)
  5. 因为在第4行中,第4行中的updatedValue =第2行中的currentValue,第5行与第2行进行比较。 对于第 5 行:如果 currentValue 在第 2 行的 LB 和 UB 之内(如果 9 在 12 和 28 之间),则第 5 行的 updatedValue 等于第 2 行的 currentValue(第 5 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 9)
  6. 因为在第 5 行中,updatedValue = currentValue,第 6 行与第 5 行进行比较。 对于第 6 行:如果 currentValue 在第 5 行的 LB 和 UB 范围内(如果 23 在 5.4 和 12.6 之间),则第 6 行的 updatedValue 等于第 2 行的 currentValue(第 5 行的 updatedValue = 20)。否则,updatedValue = currentValue (updatedValue = 9)

确切的规则将适用于 usr002。预期输出如下:

+---------+-----+--------------+------+------+--------------+
| user_id | row | currentValue |  UB  |  LB  | updatedValue |
+---------+-----+--------------+------+------+--------------+
| usr001  |   1 |           12 | 7.2  | 16.8 |           12 |
| usr001  |   2 |           20 | 12   | 28   |           20 |
| usr001  |   3 |           17 | 10.2 | 23.8 |           20 |
| usr001  |   4 |           21 | 12.6 | 29.4 |           20 |
| usr001  |   5 |            9 | 5.4  | 12.6 |            9 |
| usr001  |   6 |           23 | 13.8 | 32.2 |           23 |
| usr002  |   1 |           11 | 6.6  | 15.4 |           11 |
| usr002  |   2 |           10 | 6    | 14   |           11 |
| usr002  |   3 |           15 | 9    | 21   |           11 |
| usr002  |   4 |            3 | 1.8  | 4.2  |            3 |
| usr002  |   5 |            4 | 2.4  | 5.6  |            3 |
+---------+-----+--------------+------+------+--------------+

有什么方法可以在 Spark 中实现吗?感谢您的帮助!

Spark: 2.4.4

您可以使用 window 功能。但它不是那么简单。这是对代码和逻辑的逐步解释。

(在下面的代码和解释中uv和updatedValue是不一样的)

1.Read df

df=spark.read.csv(path, header=True, inferSchema=True)

2.Specify window

w=Window.partitionBy("user_id").orderBy("row")

3.Create 将当前值与前一行的 UB 和 LB 进行比较的列,如果它在范围内,则 return 前一行 currentValue 否则 return 同一行 currentValue,让将此列称为 "uv"

df2=df.withColumn("uv",when(col("row")==1,col("currentValue"))
.when(col("currentValue").between(lag("UB",1).over(w),                                          
lag("LB",1).over(w)),lag("currentValue",1).over(w))
.otherwise(col("currentValue"))).orderBy("user_id")

df2:

+-------+---+------------+----+----+---+
|user_id|row|currentValue|  UB|  LB| uv|
+-------+---+------------+----+----+---+
| usr001|  1|          12| 7.2|16.8| 12|
| usr001|  2|          20|12.0|28.0| 20|
| usr001|  3|          17|10.2|23.8| 20|
| usr001|  4|          21|12.6|29.4| 17|
| usr001|  5|           9| 5.4|12.6|  9|
| usr001|  6|          23|13.8|32.2| 23|
| usr002|  1|          11| 6.6|15.4| 11|
| usr002|  2|          10| 6.0|14.0| 11|
| usr002|  3|          15| 9.0|21.0| 15|
| usr002|  4|           3| 1.8| 4.2|  3|
| usr002|  5|           4| 2.4| 5.6|  3|
+-------+---+------------+----+----+---+

4.This 是主要逻辑,根据第 5 行 (usr001) 的逻辑,首先我们必须检查第 4 行更新值是否填充了第 4 行当前值,如果已填充则将第 5 行值与行进行比较4 个边界,否则我们必须转到填充第 4 行 updatedValue 的行并与这些边界进行比较,以在上述步骤中实现此标记所有值,其中 currentValue==uv.

df3=df2.withColumn("comp_row", when(col("currentValue")==col("uv"), col("row")))

df3:

+-------+---+------------+----+----+---+--------+
|user_id|row|currentValue|  UB|  LB| uv|comp_row|
+-------+---+------------+----+----+---+--------+
| usr001|  1|          12| 7.2|16.8| 12|       1|
| usr001|  2|          20|12.0|28.0| 20|       2|
| usr001|  3|          17|10.2|23.8| 20|    null|
| usr001|  4|          21|12.6|29.4| 17|    null|
| usr001|  5|           9| 5.4|12.6|  9|       5|
| usr001|  6|          23|13.8|32.2| 23|       6|
| usr002|  1|          11| 6.6|15.4| 11|       1|
| usr002|  2|          10| 6.0|14.0| 11|    null|
| usr002|  3|          15| 9.0|21.0| 15|       3|
| usr002|  4|           3| 1.8| 4.2|  3|       4|
| usr002|  5|           4| 2.4| 5.6|  3|    null|
+-------+---+------------+----+----+---+--------+

5.Now 如果我们回填每行的空值,我们将得到每行应该与之比较的行号。

df4 = df3.withColumn("comp_row",last("comp_row",True).over(w))

df4:

+-------+---+------------+----+----+---+--------+
|user_id|row|currentValue|  UB|  LB| uv|comp_row|
+-------+---+------------+----+----+---+--------+
| usr001|  1|          12| 7.2|16.8| 12|       1|
| usr001|  2|          20|12.0|28.0| 20|       2|
| usr001|  3|          17|10.2|23.8| 20|       2|
| usr001|  4|          21|12.6|29.4| 17|       2|
| usr001|  5|           9| 5.4|12.6|  9|       5|
| usr001|  6|          23|13.8|32.2| 23|       6|
| usr002|  1|          11| 6.6|15.4| 11|       1|
| usr002|  2|          10| 6.0|14.0| 11|       1|
| usr002|  3|          15| 9.0|21.0| 15|       3|
| usr002|  4|           3| 1.8| 4.2|  3|       4|
| usr002|  5|           4| 2.4| 5.6|  3|       4|
+-------+---+------------+----+----+---+--------+

注意:comp_row 的值表示下一行应该与哪一行进行比较,例如:第 4 行(usr001)comp_row 包含 2 表示第 5 行与第 2 行进行比较。

6.Now 我们知道哪一行与哪一行比较,我们需要做的只是获取这些行的边界。为此,我们需要将行与 comp_row 连接起来,这样我们就可以在第 4 行中获得第 2 行的边界。

df5 = df4.select("user_id",col("row").alias("comp_row"),
col("UB").alias("new_UB"),col("LB").alias("new_LB")
,col("currentValue").alias("new_currentValue")) 
# Note: Here row is selected as comp_row.

df6=df5.join(df4,["user_id","comp_row"],"inner").orderBy("user_id","row")

df6.select("user_id",
"UB","LB"
,"new_UB","new_LB"
,"currentValue","new_currentValue"
,"row","comp_row").show()

+-------+----+----+------+------+------------+----------------+---+--------+
|user_id|  UB|  LB|new_UB|new_LB|currentValue|new_currentValue|row|comp_row|
+-------+----+----+------+------+------------+----------------+---+--------+
| usr001| 7.2|16.8|   7.2|  16.8|          12|              12|  1|       1|
| usr001|12.0|28.0|  12.0|  28.0|          20|              20|  2|       2|
| usr001|10.2|23.8|  12.0|  28.0|          17|              20|  3|       2|
| usr001|12.6|29.4|  12.0|  28.0|          21|              20|  4|       2|
| usr001| 5.4|12.6|   5.4|  12.6|           9|               9|  5|       5|
| usr001|13.8|32.2|  13.8|  32.2|          23|              23|  6|       6|
| usr002| 6.6|15.4|   6.6|  15.4|          11|              11|  1|       1|
| usr002| 6.0|14.0|   6.6|  15.4|          10|              11|  2|       1|
| usr002| 9.0|21.0|   9.0|  21.0|          15|              15|  3|       3|
| usr002| 1.8| 4.2|   1.8|   4.2|           3|               3|  4|       4|
| usr002| 2.4| 5.6|   1.8|   4.2|           4|               3|  5|       4|
+-------+----+----+------+------+------------+----------------+---+--------+

7.The final Step and Boom!!,将 currentValues 与前一行中的新边界进行比较,如果它在边界内,则 updatedValue=new_currentValue 前一行的 else updatedValue=currentValue 同一行。

df7=df6.withColumn("updatedValue",when(col("row")==1,col("currentValue"))\
.when(col("currentValue").between(lag("new_UB",1).over(w),                                               
lag("new_LB",1).over(w)),lag("new_currentValue",1).over(w))             
.otherwise(col("currentValue"))).orderBy("user_id")\
.select("user_id","currentValue","UB","LB","updatedValue")

df7:

+-------+------------+----+----+------------+
|user_id|currentValue|  UB|  LB|updatedValue|
+-------+------------+----+----+------------+
| usr001|          12| 7.2|16.8|          12|
| usr001|          20|12.0|28.0|          20|
| usr001|          17|10.2|23.8|          20|
| usr001|          21|12.6|29.4|          20|
| usr001|           9| 5.4|12.6|           9|
| usr001|          23|13.8|32.2|          23|
| usr002|          11| 6.6|15.4|          11|
| usr002|          10| 6.0|14.0|          11|
| usr002|          15| 9.0|21.0|          11|
| usr002|           3| 1.8| 4.2|           3|
| usr002|           4| 2.4| 5.6|           3|
+-------+------------+----+----+------------+