如何根据 Pyspark 中聚合函数的条件对计数进行分组?

How to group by a count based on a condition over an aggregated function in Pyspark?

假设我构建了以下示例数据集:

import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime

spark = SparkSession.builder\
    .config("spark.driver.memory", "10g")\
    .config('spark.sql.repl.eagerEval.enabled', True)\ # to display df in pretty HTML
    .getOrCreate()

df = spark.createDataFrame(
    [
        ("US", "US_SL_A", datetime(2022, 1, 1), 3.8),
        ("US", "US_SL_A", datetime(2022, 1, 2), 4.3),
        ("US", "US_SL_A", datetime(2022, 1, 3), 4.3),
        ("US", "US_SL_A", datetime(2022, 1, 4), 3.95),
        ("US", "US_SL_A", datetime(2022, 1, 5), 1.),
        ("US", "US_SL_B", datetime(2022, 1, 1), 4.3),
        ("US", "US_SL_B", datetime(2022, 1, 2), 3.8),
        ("US", "US_SL_B", datetime(2022, 1, 3), 9.),
        ("US", "US_SL_C", datetime(2022, 1, 1), 1.),
        ("ES", "ES_SL_A", datetime(2022, 1, 1), 4.2),
        ("ES", "ES_SL_A", datetime(2022, 1, 2), 1.),
        ("ES", "ES_SL_B", datetime(2022, 1, 1), 2.),
        ("FR", "FR_SL_A", datetime(2022, 1, 1), 2.),
    ],
    schema = ("country", "platform", "timestamp", "size")
)

>> df.show()
+-------+--------+-------------------+----+
|country|platform|          timestamp|size|
+-------+--------+-------------------+----+
|     US| US_SL_A|2022-01-01 00:00:00| 3.8|
|     US| US_SL_A|2022-01-02 00:00:00| 4.3|
|     US| US_SL_A|2022-01-03 00:00:00| 4.3|
|     US| US_SL_A|2022-01-04 00:00:00|3.95|
|     US| US_SL_A|2022-01-05 00:00:00| 1.0|
|     US| US_SL_B|2022-01-01 00:00:00| 4.3|
|     US| US_SL_B|2022-01-02 00:00:00| 3.8|
|     US| US_SL_B|2022-01-03 00:00:00| 9.0|
|     US| US_SL_C|2022-01-01 00:00:00| 1.0|
|     ES| ES_SL_A|2022-01-01 00:00:00| 4.2|
|     ES| ES_SL_A|2022-01-02 00:00:00| 1.0|
|     ES| ES_SL_B|2022-01-01 00:00:00| 2.0|
|     FR| FR_SL_A|2022-01-01 00:00:00| 2.0|
+-------+--------+-------------------+----+

我的目标是检测尺寸列中异常值的数量,但之前是按国家和平台分组的。为此,我想使用四分位数范围作为标准;也就是说,我想计算所有那些值小于分位数 0.25 减去四分位数间距的 1.5 倍的尺寸。

我可以通过以下方式获得不同的分位数参数和每组所需的阈值:

>> df.groupBy(
    ["country", "platform"]
).agg(
    (
        F.round(1.5*(F.percentile_approx("size", 0.75) -  F.percentile_approx("size", 0.25)), 2)
    ).alias("1.5xInterquartile"),
    F.percentile_approx("size", 0.25).alias("q1"),
    F.percentile_approx("size", 0.75).alias("q3"),
)\
.withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`"))\ # Q1 - 1.5*IQR
.show()
+-------+--------+-----------------+---+---+---------+
|country|platform|1.5xInterquartile| q1| q3|threshold|
+-------+--------+-----------------+---+---+---------+
|     US| US_SL_A|             0.75|3.8|4.3|     3.05|
|     US| US_SL_B|              7.8|3.8|9.0|     -4.0|
|     US| US_SL_C|              0.0|1.0|1.0|      1.0|
|     ES| ES_SL_A|              4.8|1.0|4.2|     -3.8|
|     FR| FR_SL_A|              0.0|2.0|2.0|      2.0|
|     ES| ES_SL_B|              0.0|2.0|2.0|      2.0|
+-------+--------+-----------------+---+---+---------+

但这并不是我想要得到的。我想要的是,不是按四分位数聚合,而是按满足低于异常值阈值条件的每组行数的计数进行聚合。

期望的输出是这样的:

+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
|     US| US_SL_A|    1     |
|     US| US_SL_B|    0     |
|     US| US_SL_C|    0     |
|     ES| ES_SL_A|    0     |
|     FR| FR_SL_A|    0     |
|     ES| ES_SL_B|    0     |
+-------+--------+----------+

这是因为只有 (US, US_SL_A) 组有一个值 (1.) 低于此类组的离群值阈值

这是我实现该目标的尝试:

>> df.groupBy(
    ["country", "platform"]
).agg(
    (
        F.count(
            F.when(
                F.col("size") < F.percentile_approx("size", 0.25) - 1.5*(F.percentile_approx("size", 0.75) -  F.percentile_approx("size", 0.25)),
                True
            )
        )
    ).alias("n_outliers"),
)

但是我收到一个错误,其中指出:

AnalysisException: It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;
Aggregate [country#0, platform#1], [country#0, platform#1, count(CASE WHEN (size#3 < (percentile_approx(size#3, 0.25, 10000, 0, 0) - ((percentile_approx(size#3, 0.75, 10000, 0, 0) - percentile_approx(size#3, 0.25, 10000, 0, 0)) * 1.5))) THEN true END) AS n_outliers#732L]
+- LogicalRDD [country#0, platform#1, timestamp#2, size#3], false

这里的关键是在聚合之前使用windows函数

import pyspark.sql.window as W

w = W.Window.partitionBy(["country", "platform"])

(df
 .withColumn("1.5xInterquartile", F.round(1.5*(F.percentile_approx("size", 0.75).over(w) -  F.percentile_approx("size", 0.25).over(w)), 2))
 .withColumn("q1",F.percentile_approx("size", 0.25).over(w))
 .withColumn("q3",F.percentile_approx("size", 0.75).over(w))
 .withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`")) # Q1 - 1.5*IQR
 .groupBy(["country", "platform"])
 .agg(F.count(F.when(F.col("size") < F.col("q1") - 1.5*(F.col("q3") - F.col("q1")), 1)).alias("n_outliers"))
 .show()
)

+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
|     ES| ES_SL_A|         0|
|     ES| ES_SL_B|         0|
|     FR| FR_SL_A|         0|
|     US| US_SL_A|         1|
|     US| US_SL_B|         0|
|     US| US_SL_C|         0|
+-------+--------+----------+

您的 countpercentile_approx 都需要聚合,但看起来上面的 agg 并没有处理这些。

您可以尝试对所有聚合使用 window 函数,这将为每条记录添加 n_outliers 计数。然后,稍后您可以使用 distinct 仅获取每组 1 条记录。

w = Window.partitionBy("country", "platform")

df = (df.withColumn('n_outliers', 
         F.count(F.when(
             F.col("size") < (F.percentile_approx("size", 0.25).over(w) - 1.5*(F.percentile_approx("size", 0.75).over(w) -  F.percentile_approx("size", 0.25).over(w))),
             1
         )).over(w))
     .select('country', 'platform', 'n_outliers')
     .distinct())