PySpark - 选择每个组中的所有行

PySpark - Selecting all rows within each group

我有一个类似于下面的数据框。

from datetime import date
rdd = sc.parallelize([
     [123,date(2007,1,31),1],
     [123,date(2007,2,28),1],
     [123,date(2007,3,31),1],
     [123,date(2007,4,30),1],
     [123,date(2007,5,31),1],
     [123,date(2007,6,30),1],
     [123,date(2007,7,31),1],
     [123,date(2007,8,31),1],
     [123,date(2007,8,31),2],
     [123,date(2007,9,30),1],
     [123,date(2007,9,30),2],
     [123,date(2007,10,31),1],
     [123,date(2007,10,31),2],
     [123,date(2007,11,30),1],
     [123,date(2007,11,30),2],
     [123,date(2007,12,31),1],
     [123,date(2007,12,31),2],
     [123,date(2007,12,31),3],
     [123,date(2008,1,31),1],
     [123,date(2008,1,31),2],
     [123,date(2008,1,31),3]
])

df = rdd.toDF(['id','sale_date','sale'])
df.show()

从上面的数据框中,我想将所有行保留为相对于日期的最新销售。所以基本上,我只会为每一行设置唯一的日期。在上面的例子中,输出看起来像:

rdd_out = sc.parallelize([
        [123,date(2007,1,31),1],
        [123,date(2007,2,28),1],
        [123,date(2007,3,31),1],
        [123,date(2007,4,30),1],
        [123,date(2007,5,31),1],
        [123,date(2007,6,30),1],
        [123,date(2007,7,31),1],
        [123,date(2007,8,31),2],
        [123,date(2007,9,30),2],
        [123,date(2007,10,31),2],
        [123,date(2007,11,30),2],
        [123,date(2007,12,31),2],
        [123,date(2008,1,31),3]
         ])

df_out = rdd_out.toDF(['id','sale_date','sale'])
df_out.show()

你能指导我如何得到这个结果吗?

仅供参考 - 使用 SAS,我会取得如下结果:

proc sort data = df; 
   by id date sale;
run;

data want; 
 set df;
 by id date sale;
 if last.date;
run;

可能有很多方法可以实现这一点,但一种方法是使用 Window。使用 Window,您可以将数据分区到一个或多个列(在您的情况下为 sale_date),最重要的是,您可以按特定列对每个分区内的数据进行排序(在您的情况下降序 sale,这样最新的销售排在第一位)。所以:

from pyspark.sql.window import Window
from pyspark.sql.functions import desc
my_window = Window.partitionBy("sale_date").orderBy(desc("sale"))

然后你可以做的是将这个 Window 应用到你的 DataFrame 上,并应用许多 Window-functions 中的一个。您可以应用的函数之一是 row_number,对于每个分区,根据您的 orderBy 向每一行添加一个行号。像这样:

from pyspark.sql.functions import row_number
df_out = df.withColumn("row_number",row_number().over(my_window))

这将导致每个日期的最后一次销售将有 row_number = 1。如果您随后过滤 row_number=1,您将获得每个组的最后一次销售。

所以,完整代码:

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, desc, col
my_window = Window.partitionBy("sale_date").orderBy(desc("sale"))
df_out = (
        df
        .withColumn("row_number",row_number().over(my_window))
        .filter(col("row_number") == 1)
        .drop("row_number")
    )

此处您希望将“部门”替换为 sale_date,将“薪水”替换为 sale

这是同一件事的 none window 示例...@Cleared 的回答非常好。这个答案在非常大的数据集上可能比使用 window 表现得更好。 Windows 根据我的经验,比使用 groupBy 的逻辑等效要慢。 (请随意测试哪种方法更适合您。)Windows 编写起来非常简单且易于理解,因此如果数据较小,可能是更好的选择。

from pyspark.sql import SparkSession,Row
spark = SparkSession.builder.appName('SparkExample').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
unGroupedDf = df.select( \
  df["Department"], \
  f.struct(*[\ # Make a struct with all the record elements.
    df["Department"].alias("Dept"),\
    df["Salary"].alias("Salary"),\
    df["Name"].alias("Name")] )\
  .alias("record") )
unGroupedDf.groupBy("Department")\ #group
 .agg(f.collect_list("record")\  #Gather all the element in a group
  .alias("record"))\
  .select(\
    f.reverse(\ #Make the sort Descending
      f.array_sort(\ #Sort the array ascending
        f.col("record")\ #the struct
      )\
    )[0].alias("record"))\ #grab the "Max element in the array
    ).select( f.col("record.*") ).show() # use struct as Columns
  .show()

注意: 如果您没有指定带有 window 的 partitionBy,它将把所有数据发送到一个节点进行处理。这将是一个性能问题。