基于多个值的 pyspark 数据帧数组时如何处理

How to case when pyspark dataframe array based on multiple values

我可以使用 array_contains 检查数组是否包含值。

test = test.withColumn("my_boolean", 
    F.when(expr("array_contains('check_variable', 'a')"),
           1)
 .otherwise(0))

我不想测试一个值,而是想测试多个值。我可以嵌套:

test = test.withColumn("my_boolean", 
    F.when(expr("array_contains('check_variable', 'a')"),
           1)
    F.when(expr("array_contains('check_variable', 'b')"),
           1)
 .otherwise(0))

有没有办法在一条语句中做到这一点,伪代码:

test = test.withColumn("my_boolean", 
    F.when(expr("array_contains('check_variable', ['a','b'])"),
           1)
 .otherwise(0))

您可以在两个数组上使用 array_intersect,如果交集大于 0,则数组中至少有一个值。

示例:

spark = SparkSession.builder.getOrCreate()
data = [
    {"id": 1, "test": ["A", "B"]},
    {"id": 2, "test": ["E", "C"]},
    {"id": 3, "test": ["D", "B"]},
]
df = spark.createDataFrame(data)
df = df.withColumn(
    "result",
    F.when(
        F.size(F.array_intersect(F.col("test"), F.array(F.lit("A"), F.lit("B"))))
        > 0,
        1,
    ).otherwise(0),
)

结果:

+---+------+------+                                                             
| id|  test|result|
+---+------+------+
|  1|[A, B]|     1|
|  2|[E, C]|     0|
|  3|[D, B]|     1|
+---+------+------+

如果你想使用expr:

F.when(
    F.expr("size(array_intersect(test, array('A', 'B'))) > 0"),
    1,
).otherwise(0)

spark>=2.4 中,您可以使用 array_intersect 并检查输出的大小是否与您要查找的值的数量相同(在您的示例中为 2)。

pyspark.sql.functions.array_intersect(col1, col2)

Collection function: returns an array of the elements in the intersection of col1 and col2, without duplicates.

代码可以如下:

test = test.withColumn("my_boolean",
    f.expr("size(array_intersect(check_variable, array(a, b))) > 0").cast("int"))

请注意,将布尔值转换为 0/1 值的另一种方法是将其转换为 int。

如果有人对 spark<2.4 解决方案感兴趣,可以构造一个基于 array_contains 的函数并迭代列数组:

from functools import reduce
def contains_at_least_one(a):
    contains = map(lambda v: f.array_contains('check_variable', f.col(v)), a)
    return reduce(lambda x, y: x | y, contains)

test = test.withColumn("my_boolean",
    contains_at_least_one(['a', 'b']).cast("int"))

您要找的函数叫做 arrays_overlap:

arrays_overlap(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2

import pyspark.sql.functions as F

test = test.withColumn(
        "my_boolean",
        F.expr("arrays_overlap(check_variable, array('a','b'))").cast("int")
)