PySpark - GroupBy 和具有多个条件的聚合

PySpark - GroupBy and aggregation with multiple conditions

我想根据多种条件对数据进行分组和聚合。数据框包含产品 ID、故障代码、日期和故障类型。在这里,我准备了一个示例数据框:

from pyspark.sql.types import StructType,StructField, StringType, IntegerType, DateType
from datetime import datetime, date

data  = [("prod_001","fault_01",date(2020, 6, 4),"minor"),
         ("prod_001","fault_03",date(2020, 7, 2),"minor"),
         ("prod_001","fault_09",date(2020, 7, 14),"minor"),
         ("prod_001","fault_01",date(2020, 7, 14),"minor"),
         ("prod_001",None,date(2021, 4, 6),"major"),
         ("prod_001","fault_02",date(2021, 6, 22),"minor"),
         ("prod_001","fault_09",date(2021, 8, 1),"minor"),
         
         ("prod_002","fault_01",date(2020, 6, 13),"minor"),
         ("prod_002","fault_05",date(2020, 7, 11),"minor"),
         ("prod_002",None,date(2020, 8, 1),"major"),
         ("prod_002","fault_01",date(2021, 4, 15),"minor"),
         ("prod_002","fault_02",date(2021, 5, 11),"minor"),
         ("prod_002","fault_03",date(2021, 5, 13),"minor"),
  ]

schema = StructType([ \
    StructField("product_id",StringType(),True), \
    StructField("fault_code",StringType(),True), \
    StructField("date",DateType(),True), \
    StructField("fault_type", StringType(), True), \
  ])
 
df = spark.createDataFrame(data=data,schema=schema)
display(df)

总的来说,我想根据 product_id 进行分组,然后对日期进行 fault_codes (列表)的聚合。这里的一些特色是持续聚合到列表,直到 fault_type 从次要变为主要。在这种情况下,主要标记行将采用聚合的最后状态(见屏幕截图)。在一个 product_id 中,列表的聚合应该从新开始(以下 fault_code 被标记为次要)。

see target output here

在其他一些帖子中,我发现了以下我已经尝试过的代码片段。不幸的是,我还没有在所有条件下进行完整聚合。

df.sort("product_id", "date").groupby("product_id", "date").agg(F.collect_list("fault_code"))

编辑:

Window.partitionBy() 更接近了一点,但是一旦 fault_type 使用以下代码更改为 major 后仍然无法从新启动 collect_list()

df_test = df.sort("product_id", "date").groupby("product_id", "date", "fault_type").agg(F.collect_list("fault_code")).withColumnRenamed('collect_list(fault_code)', 'fault_code_list')

window_function = Window.partitionBy("product_id").rangeBetween(Window.unboundedPreceding, Window.currentRow).orderBy("date")

df_test = df_test.withColumn("new_version_v2", F.collect_list("fault_code_list").over(Window.partitionBy("product_id").orderBy("date"))) \
                 .withColumn("new_version_v2", F.flatten("new_version_v2"))

有人知道怎么做吗?

一种可能的方法是使用 Pandas UDF 和 applyInPandas

定义一个“普通”Python函数
  • 输入是一个 Pandas 数据帧,输出是另一个数据帧。
  • 数据帧的大小无关紧要
def grp(df):
    df['a'] = 'AAA'
    df = df[df['fault_code'] == 'fault_01']
    return df[['product_id', 'a']]
使用实际 Pandas 数据帧测试此函数
  • 唯一要记住的是这个数据框只是你实际数据框的一个子集
grp(df.where('product_id == "prod_001"').toPandas())

    product_id  a
0   prod_001    AAA
3   prod_001    AAA
使用 applyInPandas
将此函数应用到 Spark 数据帧中
(df
    .groupBy('product_id')
    .applyInPandas(grp, schema)
    .show()
)

                                                                                
+----------+---+
|product_id|  a|
+----------+---+
|  prod_001|AAA|
|  prod_001|AAA|
|  prod_002|AAA|
|  prod_002|AAA|
|  prod_002|AAA|
|  prod_002|AAA|
|  prod_002|AAA|
+----------+---+

您的修改已结束。这并不那么简单,我只是想出了一个可行但不那么简洁的解决方案。

lagw = Window.partitionBy('product_id').orderBy('date')
grpw = Window.partitionBy(['product_id', 'grp']).orderBy('date').rowsBetween(Window.unboundedPreceding, 0)

df = (df.withColumn('grp', F.sum(
        (F.lag('fault_type').over(lagw).isNull()
        | (F.lag('fault_type').over(lagw) == 'major')
     ).cast('int')).over(lagw))
     .withColumn('fault_code', F.collect_list('fault_code').over(grpw)))

df.orderBy(['product_id', 'grp']).show()
# +----------+----------------------------------------+----------+----------+---+
# |product_id|                              fault_code|      date|fault_type|grp|
# +----------+----------------------------------------+----------+----------+---+
# |  prod_001|[fault_01]                              |2020-06-04|     minor|  1|
# |  prod_001|[fault_01, fault_03]                    |2020-07-02|     minor|  1|
# |  prod_001|[fault_01, fault_03, fault_09]          |2020-07-14|     minor|  1|
# |  prod_001|[fault_01, fault_03, fault_09, fault_01]|2020-07-14|     minor|  1|
# |  prod_001|[fault_01, fault_03, fault_09, fault_01]|2021-04-06|     major|  1|
# |  prod_001|[fault_02]                              |2021-06-22|     minor|  2|
# |  prod_001|[fault_02, fault_09]                    |2021-08-01|     minor|  2|
# |  prod_002|[fault_01]                              |2020-06-13|     minor|  1|
# |  prod_002|[fault_01, fault_02]                    |2020-07-11|     minor|  1|
...

解释:

首先,我创建 grp 列来对连续的“次要”+“主要”进行分类。我使用 sumlag 来查看前一行是否为“主要”,然后我递增,否则,我保持与前一行相同的值。

# If cond is True, sum 1, if False, sum 0.
F.sum((cond).cast('int'))
df.orderBy(['product_id', 'date']).select('product_id', 'date', 'fault_type', 'grp').show()

+----------+----------+----------+---+
|product_id|      date|fault_type|grp|
+----------+----------+----------+---+
|  prod_001|2020-06-04|     minor|  1|
|  prod_001|2020-07-02|     minor|  1|
|  prod_001|2020-07-14|     minor|  1|
|  prod_001|2020-07-14|     minor|  1|
|  prod_001|2021-04-06|     major|  1|
|  prod_001|2021-06-22|     minor|  2|
|  prod_001|2021-08-01|     minor|  2|
|  prod_002|2020-06-13|     minor|  1|
|  prod_002|2020-07-11|     minor|  1|
...

生成此 grp 后,我可以按 product_idgrp 进行分区以应用 collect_list