根据 pyspark 数组列中的多个字符串进行过滤

Filter on the basis of multiple strings in a pyspark array column

df = sqlContext.createDataFrame(pd.DataFrame([('id1', ['a','b','c']),
                                              ('id2', ['a','b','d']),
                                              ('id3', ['c','e','f'])], 
                                              columns=['id', 'items']))

from pyspark.sql.functions import udf, col, when
from pyspark.sql.types import ArrayType, StringType, IntegerType

filter_array_udf = udf(lambda arr: [1 if (x =='a' and x=='b') else 0 for x in arr], "int")

df2 = df.withColumn("listed1", filter_array_udf(col("items")))
df2.show(20,0)

如果某个 id 包含 'a' 或 'b' 字符串,我会尝试标记该行。其中 udf returns 为空值。我对udfs很陌生。我必须在给定的 udf 中更改什么才能获得所需的结果

df.filter(F.array_contains(F.col('items'),'a')).show()

这仅适用于单个字符串,但如果我在数组中传递 ['a', 'b']。它抛出错误

Unsupported literal type class java.util.ArrayList [a, b]

我会使用 lit(v1 and v2)。 df.select 将 return 布尔值。所以如果需要显示df,使用filter

from pyspark.sql.functions import *
df.filter(array_contains(df.items,  lit('a'and 'b'))).show()

使用array_intersect检查数组列中的元素。

from pyspark.sql import functions as f

df.withColumn('temp', f.array_intersect(f.col('items'), f.array(f.lit('a'), f.lit('b')))) \
  .withColumn('listed1', f.expr('if(temp != array(), true, false)')) \
  .show(10, False)

+---+---------+------+-------+
|id |items    |temp  |listed1|
+---+---------+------+-------+
|id1|[a, b, c]|[a, b]|true   |
|id2|[a, b, d]|[a, b]|true   |
|id3|[c, e, f]|[]    |false  |
+---+---------+------+-------+