Pyspark中如何使用when语句和array_contains根据条件创建新列?

How to use when statement and array_contains in Pyspark to create a new column based on conditions?

我正在尝试使用 filtercase-when 语句和 array_contains 表达式来过滤和标记数据集中的列,并尝试以更高效的方式进行操作比我现在好多了。

我一直无法成功地将这 3 个元素串在一起,希望有人可以提供建议,因为我目前的方法有效但效率不高。

目前:

data = [["a", [1, 2, 3]], ["b", [1, 2, 8]], ["c", [3, 5, 4]], ["d", [8, 1, 4]]]

df = pd.DataFrame(data, columns=["product", "list_of_values"])
sdf = spark.createDataFrame(df)

# partially flag using array_contains to determine if element is within list_of_values
partially_flagged_sdf = (
    sdf.withColumn(
        "contains_element1",
        spark_fns.array_contains(
            sdf.list_of_values, "1"
        ),
    )
    .withColumn(
        "contains_element2",
        spark_fns.array_contains(
            sdf.list_of_values, "2"
        ),
    )
    .withColumn(
        "contains_element3",
        spark_fns.array_contains(
            sdf.list_of_values, "3"
        ),
    )
    .withColumn(
        "contains_element4",
        spark_fns.array_contains(
            sdf.list_of_values, "4"
        ),
    )
)

# using case_when and filtering, add additional flag if product == a, and list_of_values contains element 1 or 2
flagged_sdf = partially_flagged_sdf.withColumn("proda_contains_el1", spark_fns.when((spark_fns.col("product) == 'a') & & (
        (spark_fns.col("contains_element1") == True)
        | (spark_fns.col("contains_element2") == True)
    )),True).otherwise(False)

您可以使用arrays_overlap检查多个元素:

import pyspark.sql.functions as F

df2 = sdf.withColumn(
    'newcol', 
    (F.col('product') == 'a') & 
    F.arrays_overlap('list_of_values', F.array(F.lit(1), F.lit(2)))
)

df2.show()
+-------+--------------+------+
|product|list_of_values|newcol|
+-------+--------------+------+
|      a|     [1, 2, 3]|  true|
|      b|     [1, 2, 8]| false|
|      c|     [3, 5, 4]| false|
|      d|     [8, 1, 4]| false|
+-------+--------------+------+