如何在 pyspark 中的带有数组的列上使用列表理解?

How to use list comprehension on a column with array in pyspark?

我有一个看起来像这样的 pyspark 数据框。

+--------------------+-------+--------------------+
|              ID    |country|               attrs|
+--------------------+-------+--------------------+
|ffae10af            |     US|[1,2,3,4...]        |
|3de27656            |     US|[1,7,2,4...]        |
|75ce4e58            |     US|[1,2,1,4...]        |
|908df65c            |     US|[1,8,3,0...]        |
|f0503257            |     US|[1,2,3,2...]        |
|2tBxD6j             |     US|[1,2,3,4...]        |
|33811685            |     US|[1,5,3,5...]        |
|aad21639            |     US|[7,8,9,4...]        |
|e3d9e3bb            |     US|[1,10,9,4...]       |
|463f6f69            |     US|[12,2,13,4...]      |
+--------------------+-------+--------------------+

我也有一套是这样的

reference_set = (1,2,100,500,821)

我想做的是创建一个新列表作为数据框中的一列,可能使用这样的列表理解 [attr for attr in attrs if attr in reference_set]

所以我的最终数据框应该是这样的

+--------------------+-------+--------------------+
|              ID    |country|      filtered_attrs|
+--------------------+-------+--------------------+
|ffae10af            |     US|[1,2]               |
|3de27656            |     US|[1,2]               |
|75ce4e58            |     US|[1,2]               |
|908df65c            |     US|[1]                 |
|f0503257            |     US|[1,2]               |
|2tBxD6j             |     US|[1,2]               |
|33811685            |     US|[1]                 |
|aad21639            |     US|[]                  |
|e3d9e3bb            |     US|[1]                 |
|463f6f69            |     US|[2]                 |
+--------------------+-------+--------------------+

我该怎么做?因为我是 pyspark 的新手,所以我想不出逻辑。

编辑:在下面发布了一个逻辑,如果有更有效的方法,请告诉我。

我设法将过滤函数与 UDF 结合使用来完成这项工作。

def filter_items(item):
    if item in reference_set:
        return True
    else:
        return False

custom_udf = udf(lambda attributes : list(filter(filter_items, attributes)))
processed_df = df.withColumn('filtered_attrs',custom_udf(col('attrs')))

这给了我所需的输出

您可以使用内置函数 - array_intersect

# Sample dataframe

df = spark.createDataFrame([('ffae10af', 'US', [1,2,3,4])], ["ID", "Country", "attrs"])

reference_set = {1,2,100,500,821}

# This step is to add set as column in dataframe
set_to_string = ",".join([str(x) for x in reference_set])

df.withColumn('reference_set', split(lit(set_to_string), ',').cast('array<bigint>')). \
withColumn('filtered_attrs', array_intersect('attrs','reference_set'))\ 
.show(truncate = False)

+--------+-------+------------+---------------------+--------------+
|ID      |Country|attrs       |reference_set        |filtered_attrs|
+--------+-------+------------+---------------------+--------------+
|ffae10af|US     |[1, 2, 3, 4]|[1, 2, 100, 500, 821]|[1, 2]        |
+--------+-------+------------+---------------------+--------------+