How to filter IDs which meet two conditions over another column in pyspark?

I have a table that looks like this:

id country count count_1
A36992434 MX 1 2
A36992434 ES 1 2
A00749707 ES 1 2
A00749707 MX 1 2
A10352704 PE 1 2
A10352704 ES 1 2

I want to keep the IDs for which the country column takes both the values ES and MX. So, in this case, I would like to get the following output:

id country count count_1
A36992434 MX 1 2
A36992434 ES 1 2
A00749707 ES 1 2
A00749707 MX 1 2

Thanks a lot!

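You can create a countryAgg DataFrame that holds flags for MX and ES by aggregating at the id level and then marking each id with arrays_overlap to check for both countries.

Then use filter to keep only the ids that contain both ES and MX, as follows.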
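Before the Spark version, the same idea can be sketched in plain Python to make the three steps explicit: aggregate the countries per id, keep the ids whose set covers both required values, then filter the original rows. The names rows, required, countries_by_id, keep_ids and filtered are illustrative only and do not appear in the answer.

```python
from collections import defaultdict

# Same sample data as the question, as (id, country, count, count_1) tuples
rows = [
    ("A36992434", "MX", 1, 2),
    ("A36992434", "ES", 1, 2),
    ("A00749707", "ES", 1, 2),
    ("A00749707", "MX", 1, 2),
    ("A10352704", "PE", 1, 2),
    ("A10352704", "ES", 1, 2),
]

required = {"ES", "MX"}

# Step 1: collect the set of countries seen for each id (the aggregation step)
countries_by_id = defaultdict(set)
for id_, country, _count, _count_1 in rows:
    countries_by_id[id_].add(country)

# Step 2: keep the ids whose country set contains every required country
keep_ids = {id_ for id_, seen in countries_by_id.items() if required <= seen}

# Step 3: restrict the original rows to those ids (the join step)
filtered = [row for row in rows if row[0] in keep_ids]

for row in filtered:
    print(*row)
```

The Spark code plays the same roles with collect_set for step 1, the overlap flags plus filter for step 2, and an inner join for step 3.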

Data Preparation

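from io import StringIO

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# `sql` is the Spark entry point used below; a SparkSession is assumed here
sql = SparkSession.builder.getOrCreate()

s = StringIO("""
id  country count   count_1
A36992434   MX  1   2
A36992434   ES  1   2
A00749707   ES  1   2
A00749707   MX  1   2
A10352704   PE  1   2
A10352704   ES  1   2
""")

# Split on runs of whitespace so the sample parses with tabs or spaces
df = pd.read_csv(s, sep=r'\s+')

sparkDF = sql.createDataFrame(df)

sparkDF.show()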
+---------+-------+-----+-------+
|       id|country|count|count_1|
+---------+-------+-----+-------+
|A36992434|     MX|    1|      2|
|A36992434|     ES|    1|      2|
|A00749707|     ES|    1|      2|
|A00749707|     MX|    1|      2|
|A10352704|     PE|    1|      2|
|A10352704|     ES|    1|      2|
+---------+-------+-----+-------+

Array Overlap Flags

countryAgg = sparkDF.groupBy(F.col('id')).agg(F.collect_set(F.col('country')).alias('country_set'))

countryAgg = (
    countryAgg
    .withColumn('country_check_mx', F.array(F.lit('MX')))
    .withColumn('country_check_es', F.array(F.lit('ES')))
    .withColumn('overlap_flag_mx', F.arrays_overlap(F.col('country_set'), F.col('country_check_mx')))
    .withColumn('overlap_flag_es', F.arrays_overlap(F.col('country_set'), F.col('country_check_es')))
)

countryAgg.show()

+---------+-----------+----------------+----------------+---------------+---------------+
|       id|country_set|country_check_mx|country_check_es|overlap_flag_mx|overlap_flag_es|
+---------+-----------+----------------+----------------+---------------+---------------+
|A36992434|   [MX, ES]|            [MX]|            [ES]|           true|           true|
|A00749707|   [ES, MX]|            [MX]|            [ES]|           true|           true|
|A10352704|   [ES, PE]|            [MX]|            [ES]|          false|           true|
+---------+-----------+----------------+----------------+---------------+---------------+

Join

countryAgg = countryAgg.filter((F.col('overlap_flag_mx') & F.col('overlap_flag_es')))

sparkDF.join(countryAgg
             ,sparkDF['id'] == countryAgg['id']
             ,'inner'
).select(sparkDF['*'])\
.show()

+---------+-------+-----+-------+
|       id|country|count|count_1|
+---------+-------+-----+-------+
|A36992434|     MX|    1|      2|
|A36992434|     ES|    1|      2|
|A00749707|     ES|    1|      2|
|A00749707|     MX|    1|      2|
+---------+-------+-----+-------+