Group data in PySpark and get the top-n rows in each group

I have some data that can be represented simply as:

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

conf = SparkConf().setMaster("local[*]").setAppName("test")
sc = SparkContext(conf=conf).getOrCreate()
spark = SparkSession(sparkContext=sc).builder.getOrCreate()

rdd = sc.parallelize([(1, 10), (3, 11), (1, 8), (1, 12), (3, 7), (3, 9)])
data = spark.createDataFrame(rdd, ['x', 'y'])
data.show()

def f(x):
    y = sorted(x, reverse=True)[:2]
    return y

h_f = udf(f, IntegerType())
h_f = spark.udf.register("h_f", h_f)
data.groupBy('x').agg({"y": h_f}).show()

But it fails with AttributeError: 'function' object has no attribute '_get_object_id'. How can I get the top-n items in each group?

Assuming you are looking for the top n 'y' elements of each 'x' group: the error occurs because the dict form of agg only accepts the string names of built-in aggregate functions (e.g. {"y": "max"}), not a udf object, and an ordinary Python UDF is applied row by row, so it cannot aggregate a group on its own.
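
If you do want to keep a udf, one workaround (a sketch, assuming the same data DataFrame from the question and a top-n of 2; the name top2 is mine) is to collect each group into an array first and apply the udf to that array column, declaring a matching ArrayType return type:

from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, IntegerType

# The udf receives the whole collected array for each group,
# so it can sort and slice it; the return type must be ArrayType.
top2 = F.udf(lambda ys: sorted(ys, reverse=True)[:2], ArrayType(IntegerType()))
data.groupBy('x').agg(top2(F.collect_list('y')).alias('y_top2')).show()

That said, built-in array functions avoid the Python serialization overhead of a udf entirely: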

from pyspark.sql import functions as F

rdd = sc.parallelize([(1, 10), (3, 11), (1, 8), (1, 12), (3, 7), (3, 9)])
df = spark.createDataFrame(rdd, ['x', 'y'])
df.show()

# Collect each group's 'y' values into an array, sort the array in
# descending order, and keep the first two elements.
df_g = df.groupBy('x').agg(F.collect_list('y').alias('y'))
df_g = df_g.withColumn('y_sorted', F.sort_array('y', asc=False))
df_g.withColumn('y_slice', F.slice(df_g.y_sorted, 1, 2)).show()

Output:

+---+-----------+-----------+--------+
|  x|          y|   y_sorted| y_slice|
+---+-----------+-----------+--------+
|  1|[10, 8, 12]|[12, 10, 8]|[12, 10]|
|  3| [11, 7, 9]| [11, 9, 7]| [11, 9]|
+---+-----------+-----------+--------+
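
Note that F.slice is only available from Spark 2.4 onward. If you need one row per kept value instead of an array column, or you are on an older Spark version, a window function with row_number is a common alternative; a minimal sketch over the same df (top 2 per group):

from pyspark.sql import Window
from pyspark.sql import functions as F

# Rank rows within each 'x' partition by descending 'y',
# then keep the two highest-ranked rows per group.
w = Window.partitionBy('x').orderBy(F.desc('y'))
df.withColumn('rn', F.row_number().over(w)).filter(F.col('rn') <= 2).show()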