PySpark dataframe 如何过滤数据?

PySpark dataframe How to filter data?

我有一个数据框,其中包含部门、项目 ID 和这些 ID 的数量。有 127 个部门,我想获取每个部门的前 10 个项目并列出它们。这意味着基于项目计数,我想分别列出每个部门的前 10 个项目。我一直在尝试使用 groupBy 和 agg.max 来做到这一点,但没有成功。下面列出了数据框的示例。

Department Item id count
A 101 10
B 102 5
A 104 12
C 101 5
D 104 14
C 108 10

解决方案基于row_number() windows函数。

  • 在此演示中,我返回了每个部门的前 3 条记录。随意 将其更改为 10。
  • qualify 是 Spark SQL 的新手。如果您的 Spark 版本不支持它,则需要包装查询,过滤器将在外部查询上使用 WHERE 子句完成。
  • 我将 Item id 添加到 ORDER BY 以便确定地打破 count 关系。

数据样本创建

df = spark.sql('''select char(ascii('A') + d.i) as Department, 100 + i.i as `Item id`, int(rand()*100) as count from range(3) as d(i), range(7) as i(i) order by 1,3 desc''')

df.show(999)

+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
|         A|    103|   89|
|         A|    106|   68|
|         A|    104|   54|
|         A|    100|   52|
|         A|    105|   50|
|         A|    102|   40|
|         A|    101|   30|
|         B|    104|   94|
|         B|    101|   87|
|         B|    106|   74|
|         B|    105|   66|
|         B|    102|   48|
|         B|    100|   32|
|         B|    103|   14|
|         C|    105|   95|
|         C|    103|   94|
|         C|    102|   90|
|         C|    104|   82|
|         C|    100|    9|
|         C|    101|    6|
|         C|    106|    3|
+----------+-------+-----+

Spark SQL 解决方案

df.createOrReplaceTempView('t')

sql_query = '''
select  *
from    t
qualify row_number() over (partition by Department order by count desc, `Item id`) <= 3
'''

spark.sql(sql_query).show(999)
  
+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
|         A|    103|   89|
|         A|    106|   68|
|         A|    104|   54|
|         B|    104|   94|
|         B|    101|   87|
|         B|    106|   74|
|         C|    105|   95|
|         C|    103|   94|
|         C|    102|   90|
+----------+-------+-----+

pyspark 解决方案

import pyspark.sql.functions as F
import pyspark.sql.window as W

(df.withColumn('rn', F.row_number().over(W.Window.partitionBy('Department').orderBy(df['count'].desc(),df['Item id'])))
 .where('rn <= 3')
 .drop('rn')
 .show(999)
)

+----------+-------+-----+
|Department|Item id|count|
+----------+-------+-----+
|         A|    103|   89|
|         A|    106|   68|
|         A|    104|   54|
|         B|    104|   94|
|         B|    101|   87|
|         B|    106|   74|
|         C|    105|   95|
|         C|    103|   94|
|         C|    102|   90|
+----------+-------+-----+