根据 PySpark 中另一个数据框的列值的值更新列中的值

Update values in a column based on values of another data frame's column values in PySpark

我在 PySpark 中有两个数据框:df1

+---+-----------------+
|id1|           items1|
+---+-----------------+
|  0|     [B, C, D, E]|
|  1|        [E, A, C]|
|  2|     [F, A, E, B]|
|  3|        [E, G, A]|
|  4|  [A, C, E, B, D]|
+---+-----------------+ 

df2

+---+-----------------+
|id2|           items2|
+---+-----------------+
|001|              [B]|
|002|              [A]|
|003|              [C]|
|004|              [E]|
+---+-----------------+ 

我想在 df1 中创建一个新列来更新值 items1 列,因此它只保留在 df2 中也出现(在 items2 的任何行中)的值。结果应如下所示:

+---+-----------------+----------------------+
|id1|           items1|        items1_updated|
+---+-----------------+----------------------+
|  0|     [B, C, D, E]|             [B, C, E]|
|  1|        [E, A, C]|             [E, A, C]|
|  2|     [F, A, E, B]|             [A, E, B]|
|  3|        [E, G, A]|                [E, A]|
|  4|  [A, C, E, B, D]|          [A, C, E, B]|
+---+-----------------+----------------------+

我通常会使用 collect() 来获取 items2 列中所有值的列表,然后使用应用于 items1 中每一行的 udf 来获取交集。但是数据非常大(超过 1000 万行),我无法使用 collect() 来获取这样的列表。有没有办法在以数据帧格式保存数据的同时做到这一点?或者不使用 collect() 的其他方式?

您要做的第一件事是 explode df2.items2 中的值,这样数组的内容将在不同的行上:

from pyspark.sql.functions import explode
df2 = df2.select(explode("items2").alias("items2"))
df2.show()
#+------+
#|items2|
#+------+
#|     B|
#|     A|
#|     C|
#|     E|
#+------+

(假设 df2.items2 中的值是不同的 - 如果不是,您需要添加 df2 = df2.distinct()。)

选项 1:使用 crossJoin

现在您可以 crossJoin 将新的 df2 返回到 df1 并仅保留 df1.items1 包含 df2.items2 中元素的行。我们可以使用 pyspark.sql.functions.array_contains and that allows us to 来实现这一点。

筛选后,按 id1items1 分组并使用 pyspark.sql.functions.collect_list

聚合
from pyspark.sql.functions import expr, collect_list

df1.alias("l").crossJoin(df2.alias("r"))\
    .where(expr("array_contains(l.items1, r.items2)"))\
    .groupBy("l.id1", "l.items1")\
    .agg(collect_list("r.items2").alias("items1_updated"))\
    .show()
#+---+---------------+--------------+
#|id1|         items1|items1_updated|
#+---+---------------+--------------+
#|  1|      [E, A, C]|     [A, C, E]|
#|  0|   [B, C, D, E]|     [B, C, E]|
#|  4|[A, C, E, B, D]|  [B, A, C, E]|
#|  3|      [E, G, A]|        [A, E]|
#|  2|   [F, A, E, B]|     [B, A, E]|
#+---+---------------+--------------+

选项 2:分解 df1.items1 并左连接:

另一种选择是 explode items1 中的内容 df1 并进行左连接。 join之后,我们要做一个和上面类似的group by和aggregation。这是有效的,因为 collect_list 将忽略由不匹配的行

引入的 null
df1.withColumn("items1", explode("items1")).alias("l")\
    .join(df2.alias("r"), on=expr("l.items1=r.items2"), how="left")\
    .groupBy("l.id1")\
    .agg(
        collect_list("l.items1").alias("items1"),
        collect_list("r.items2").alias("items1_updated")
    ).show()
#+---+---------------+--------------+
#|id1|         items1|items1_updated|
#+---+---------------+--------------+
#|  0|   [E, B, D, C]|     [E, B, C]|
#|  1|      [E, C, A]|     [E, C, A]|
#|  3|      [E, A, G]|        [E, A]|
#|  2|   [F, E, B, A]|     [E, B, A]|
#|  4|[E, B, D, C, A]|  [E, B, C, A]|
#+---+---------------+--------------+