过滤嵌套数组列并创建新的嵌套数组列

Filter a nested array column and create new nested array column

我有一个下面的示例数据框,我需要根据字段 colB 的内容过滤 colA

Schema for the Input

 |-- colA: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- id: string (nullable = true)
 |    |    |-- type: string (nullable = true)
 |-- colB: array (nullable = true)
 |    |-- element: string (containsNull = true)

| colA                                  | colB           |
| ------------------------------------- | -------------- |
| [{ABC, Completed}, {DEF, Pending}]    | [ABC, GHI]     |
| [{ABC, Completed}, {GHI, Failure}]    | [ABC, GHI]     |
| [{ABC, Completed}, {DEF, Pending}]    | [ABC]          |

所以,寻找下面的输出

| colA                              | colB       | colC   
| ----------------------------------| -----------| ------
| [{ABC, Completed}, {DEF, Pending}]| [ABC, GHI] | [{ABC, Completed}]
| [{ABC, Completed}, {GHI, Failure}]| [ABC, GHI] | [{ABC, Completed}, {GHI, Failure}]
| [{ABC, Completed}, {DEF, Pending}]| [ABC]      | [{ABC, Completed}]

当 colB 是字符串时,我能够使用高阶函数找出逻辑。下面是它的代码片段,需要帮助将其扩展到 colB 是字符串数组时

inputDF
      .withColumn(
        "colC",
        expr(
          "filter(colA, colA_struct -> colA_struct.id == colB)"
        )
      )

array_contains函数判断colB是否包含colA.id,然后用filter函数过滤colA,即可以得到 colC.

import pyspark.sql.functions as F
......
df = df.withColumn('colC', F.expr('filter(colA, x -> array_contains(colB, x.id))'))
df.show(truncate=False)