如何在 pyspark 的另一列上过滤满足两个条件的 ID?
How to filter IDs which meet two conditions over another column in pyspark?
我有一个 table 看起来像这样:
id
country
count
count_1
A36992434
MX
1
2
A36992434
ES
1
2
A00749707
ES
1
2
A00749707
MX
1
2
A10352704
PE
1
2
A10352704
ES
1
2
我想保留国家/地区列采用值 ES 和 MX 的 ID。因此,在这种情况下,我希望获得显示以下内容的输出:
id
country
count
count_1
A36992434
MX
1
2
A36992434
ES
1
2
A00749707
ES
1
2
A00749707
MX
1
2
非常感谢!
您可以创建一个 countryAgg
数据框,它将包含 MX
和 ES
的标志,方法是在 id
级别聚合它并进一步用 [= 标记它20=] 检查两个国家
并进一步利用 filter 仅过滤包含 ES
和 MX
的 ids
,如下所示 -
数据准备
s = StringIO("""
id country count count_1
A36992434 MX 1 2
A36992434 ES 1 2
A00749707 ES 1 2
A00749707 MX 1 2
A10352704 PE 1 2
A10352704 ES 1 2
""")
df = pd.read_csv(s,delimiter='\t')
sparkDF = sql.createDataFrame(df)
sparkDF.show()
+---------+-------+-----+-------+
| id|country|count|count_1|
+---------+-------+-----+-------+
|A36992434| MX| 1| 2|
|A36992434| ES| 1| 2|
|A00749707| ES| 1| 2|
|A00749707| MX| 1| 2|
|A10352704| PE| 1| 2|
|A10352704| ES| 1| 2|
+---------+-------+-----+-------+
数组重叠标记
countryAgg = sparkDF.groupBy(F.col('id')).agg(F.collect_set(F.col('country')).alias('country_set'))
countryAgg = countryAgg.withColumn('country_check_mx',F.array(F.lit('MX')))\
.withColumn('country_check_es',F.array(F.lit('ES')))\
.withColumn("overlap_flag_mx"
,F.arrays_overlap(F.col("country_set"),F.col("country_check_mx"))
)\
.withColumn("overlap_flag_es"
,F.arrays_overlap(F.col("country_set"),F.col("country_check_es"))
)
countryAgg.show()
+---------+-----------+----------------+----------------+---------------+---------------+
| id|country_set|country_check_mx|country_check_es|overlap_flag_mx|overlap_flag_es|
+---------+-----------+----------------+----------------+---------------+---------------+
|A36992434| [MX, ES]| [MX]| [ES]| true| true|
|A00749707| [ES, MX]| [MX]| [ES]| true| true|
|A10352704| [ES, PE]| [MX]| [ES]| false| true|
+---------+-----------+----------------+----------------+---------------+---------------+
加入
countryAgg = countryAgg.filter((F.col('overlap_flag_mx') & F.col('overlap_flag_es')))
sparkDF.join(countryAgg
,sparkDF['id'] == countryAgg['id']
,'inner'
).select(sparkDF['*'])\
.show()
+---------+-------+-----+-------+
| id|country|count|count_1|
+---------+-------+-----+-------+
|A36992434| MX| 1| 2|
|A36992434| ES| 1| 2|
|A00749707| ES| 1| 2|
|A00749707| MX| 1| 2|
+---------+-------+-----+-------+
我有一个 table 看起来像这样:
id | country | count | count_1 |
---|---|---|---|
A36992434 | MX | 1 | 2 |
A36992434 | ES | 1 | 2 |
A00749707 | ES | 1 | 2 |
A00749707 | MX | 1 | 2 |
A10352704 | PE | 1 | 2 |
A10352704 | ES | 1 | 2 |
我想保留国家/地区列采用值 ES 和 MX 的 ID。因此,在这种情况下,我希望获得显示以下内容的输出:
id | country | count | count_1 |
---|---|---|---|
A36992434 | MX | 1 | 2 |
A36992434 | ES | 1 | 2 |
A00749707 | ES | 1 | 2 |
A00749707 | MX | 1 | 2 |
非常感谢!
您可以创建一个 countryAgg
数据框,它将包含 MX
和 ES
的标志,方法是在 id
级别聚合它并进一步用 [= 标记它20=] 检查两个国家
并进一步利用 filter 仅过滤包含 ES
和 MX
的 ids
,如下所示 -
数据准备
s = StringIO("""
id country count count_1
A36992434 MX 1 2
A36992434 ES 1 2
A00749707 ES 1 2
A00749707 MX 1 2
A10352704 PE 1 2
A10352704 ES 1 2
""")
df = pd.read_csv(s,delimiter='\t')
sparkDF = sql.createDataFrame(df)
sparkDF.show()
+---------+-------+-----+-------+
| id|country|count|count_1|
+---------+-------+-----+-------+
|A36992434| MX| 1| 2|
|A36992434| ES| 1| 2|
|A00749707| ES| 1| 2|
|A00749707| MX| 1| 2|
|A10352704| PE| 1| 2|
|A10352704| ES| 1| 2|
+---------+-------+-----+-------+
数组重叠标记
countryAgg = sparkDF.groupBy(F.col('id')).agg(F.collect_set(F.col('country')).alias('country_set'))
countryAgg = countryAgg.withColumn('country_check_mx',F.array(F.lit('MX')))\
.withColumn('country_check_es',F.array(F.lit('ES')))\
.withColumn("overlap_flag_mx"
,F.arrays_overlap(F.col("country_set"),F.col("country_check_mx"))
)\
.withColumn("overlap_flag_es"
,F.arrays_overlap(F.col("country_set"),F.col("country_check_es"))
)
countryAgg.show()
+---------+-----------+----------------+----------------+---------------+---------------+
| id|country_set|country_check_mx|country_check_es|overlap_flag_mx|overlap_flag_es|
+---------+-----------+----------------+----------------+---------------+---------------+
|A36992434| [MX, ES]| [MX]| [ES]| true| true|
|A00749707| [ES, MX]| [MX]| [ES]| true| true|
|A10352704| [ES, PE]| [MX]| [ES]| false| true|
+---------+-----------+----------------+----------------+---------------+---------------+
加入
countryAgg = countryAgg.filter((F.col('overlap_flag_mx') & F.col('overlap_flag_es')))
sparkDF.join(countryAgg
,sparkDF['id'] == countryAgg['id']
,'inner'
).select(sparkDF['*'])\
.show()
+---------+-------+-----+-------+
| id|country|count|count_1|
+---------+-------+-----+-------+
|A36992434| MX| 1| 2|
|A36992434| ES| 1| 2|
|A00749707| ES| 1| 2|
|A00749707| MX| 1| 2|
+---------+-------+-----+-------+