如何根据 Pyspark 中聚合函数的条件对计数进行分组?
How to group by a count based on a condition over an aggregated function in Pyspark?
假设我构建了以下示例数据集:
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime
spark = SparkSession.builder\
.config("spark.driver.memory", "10g")\
.config('spark.sql.repl.eagerEval.enabled', True)\ # to display df in pretty HTML
.getOrCreate()
df = spark.createDataFrame(
[
("US", "US_SL_A", datetime(2022, 1, 1), 3.8),
("US", "US_SL_A", datetime(2022, 1, 2), 4.3),
("US", "US_SL_A", datetime(2022, 1, 3), 4.3),
("US", "US_SL_A", datetime(2022, 1, 4), 3.95),
("US", "US_SL_A", datetime(2022, 1, 5), 1.),
("US", "US_SL_B", datetime(2022, 1, 1), 4.3),
("US", "US_SL_B", datetime(2022, 1, 2), 3.8),
("US", "US_SL_B", datetime(2022, 1, 3), 9.),
("US", "US_SL_C", datetime(2022, 1, 1), 1.),
("ES", "ES_SL_A", datetime(2022, 1, 1), 4.2),
("ES", "ES_SL_A", datetime(2022, 1, 2), 1.),
("ES", "ES_SL_B", datetime(2022, 1, 1), 2.),
("FR", "FR_SL_A", datetime(2022, 1, 1), 2.),
],
schema = ("country", "platform", "timestamp", "size")
)
>> df.show()
+-------+--------+-------------------+----+
|country|platform| timestamp|size|
+-------+--------+-------------------+----+
| US| US_SL_A|2022-01-01 00:00:00| 3.8|
| US| US_SL_A|2022-01-02 00:00:00| 4.3|
| US| US_SL_A|2022-01-03 00:00:00| 4.3|
| US| US_SL_A|2022-01-04 00:00:00|3.95|
| US| US_SL_A|2022-01-05 00:00:00| 1.0|
| US| US_SL_B|2022-01-01 00:00:00| 4.3|
| US| US_SL_B|2022-01-02 00:00:00| 3.8|
| US| US_SL_B|2022-01-03 00:00:00| 9.0|
| US| US_SL_C|2022-01-01 00:00:00| 1.0|
| ES| ES_SL_A|2022-01-01 00:00:00| 4.2|
| ES| ES_SL_A|2022-01-02 00:00:00| 1.0|
| ES| ES_SL_B|2022-01-01 00:00:00| 2.0|
| FR| FR_SL_A|2022-01-01 00:00:00| 2.0|
+-------+--------+-------------------+----+
我的目标是检测尺寸列中异常值的数量,但之前是按国家和平台分组的。为此,我想使用四分位数范围作为标准;也就是说,我想计算所有那些值小于分位数 0.25 减去四分位数间距的 1.5 倍的尺寸。
我可以通过以下方式获得不同的分位数参数和每组所需的阈值:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.round(1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)), 2)
).alias("1.5xInterquartile"),
F.percentile_approx("size", 0.25).alias("q1"),
F.percentile_approx("size", 0.75).alias("q3"),
)\
.withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`"))\ # Q1 - 1.5*IQR
.show()
+-------+--------+-----------------+---+---+---------+
|country|platform|1.5xInterquartile| q1| q3|threshold|
+-------+--------+-----------------+---+---+---------+
| US| US_SL_A| 0.75|3.8|4.3| 3.05|
| US| US_SL_B| 7.8|3.8|9.0| -4.0|
| US| US_SL_C| 0.0|1.0|1.0| 1.0|
| ES| ES_SL_A| 4.8|1.0|4.2| -3.8|
| FR| FR_SL_A| 0.0|2.0|2.0| 2.0|
| ES| ES_SL_B| 0.0|2.0|2.0| 2.0|
+-------+--------+-----------------+---+---+---------+
但这并不是我想要得到的。我想要的是,不是按四分位数聚合,而是按满足低于异常值阈值条件的每组行数的计数进行聚合。
期望的输出是这样的:
+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
| US| US_SL_A| 1 |
| US| US_SL_B| 0 |
| US| US_SL_C| 0 |
| ES| ES_SL_A| 0 |
| FR| FR_SL_A| 0 |
| ES| ES_SL_B| 0 |
+-------+--------+----------+
这是因为只有 (US, US_SL_A)
组有一个值 (1.) 低于此类组的离群值阈值
这是我实现该目标的尝试:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.count(
F.when(
F.col("size") < F.percentile_approx("size", 0.25) - 1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)),
True
)
)
).alias("n_outliers"),
)
但是我收到一个错误,其中指出:
AnalysisException: It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;
Aggregate [country#0, platform#1], [country#0, platform#1, count(CASE WHEN (size#3 < (percentile_approx(size#3, 0.25, 10000, 0, 0) - ((percentile_approx(size#3, 0.75, 10000, 0, 0) - percentile_approx(size#3, 0.25, 10000, 0, 0)) * 1.5))) THEN true END) AS n_outliers#732L]
+- LogicalRDD [country#0, platform#1, timestamp#2, size#3], false
这里的关键是在聚合之前使用windows函数
import pyspark.sql.window as W
w = W.Window.partitionBy(["country", "platform"])
(df
.withColumn("1.5xInterquartile", F.round(1.5*(F.percentile_approx("size", 0.75).over(w) - F.percentile_approx("size", 0.25).over(w)), 2))
.withColumn("q1",F.percentile_approx("size", 0.25).over(w))
.withColumn("q3",F.percentile_approx("size", 0.75).over(w))
.withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`")) # Q1 - 1.5*IQR
.groupBy(["country", "platform"])
.agg(F.count(F.when(F.col("size") < F.col("q1") - 1.5*(F.col("q3") - F.col("q1")), 1)).alias("n_outliers"))
.show()
)
+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
| ES| ES_SL_A| 0|
| ES| ES_SL_B| 0|
| FR| FR_SL_A| 0|
| US| US_SL_A| 1|
| US| US_SL_B| 0|
| US| US_SL_C| 0|
+-------+--------+----------+
您的 count
和 percentile_approx
都需要聚合,但看起来上面的 agg
并没有处理这些。
您可以尝试对所有聚合使用 window 函数,这将为每条记录添加 n_outliers
计数。然后,稍后您可以使用 distinct
仅获取每组 1 条记录。
w = Window.partitionBy("country", "platform")
df = (df.withColumn('n_outliers',
F.count(F.when(
F.col("size") < (F.percentile_approx("size", 0.25).over(w) - 1.5*(F.percentile_approx("size", 0.75).over(w) - F.percentile_approx("size", 0.25).over(w))),
1
)).over(w))
.select('country', 'platform', 'n_outliers')
.distinct())
假设我构建了以下示例数据集:
import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from datetime import datetime
spark = SparkSession.builder\
.config("spark.driver.memory", "10g")\
.config('spark.sql.repl.eagerEval.enabled', True)\ # to display df in pretty HTML
.getOrCreate()
df = spark.createDataFrame(
[
("US", "US_SL_A", datetime(2022, 1, 1), 3.8),
("US", "US_SL_A", datetime(2022, 1, 2), 4.3),
("US", "US_SL_A", datetime(2022, 1, 3), 4.3),
("US", "US_SL_A", datetime(2022, 1, 4), 3.95),
("US", "US_SL_A", datetime(2022, 1, 5), 1.),
("US", "US_SL_B", datetime(2022, 1, 1), 4.3),
("US", "US_SL_B", datetime(2022, 1, 2), 3.8),
("US", "US_SL_B", datetime(2022, 1, 3), 9.),
("US", "US_SL_C", datetime(2022, 1, 1), 1.),
("ES", "ES_SL_A", datetime(2022, 1, 1), 4.2),
("ES", "ES_SL_A", datetime(2022, 1, 2), 1.),
("ES", "ES_SL_B", datetime(2022, 1, 1), 2.),
("FR", "FR_SL_A", datetime(2022, 1, 1), 2.),
],
schema = ("country", "platform", "timestamp", "size")
)
>> df.show()
+-------+--------+-------------------+----+
|country|platform| timestamp|size|
+-------+--------+-------------------+----+
| US| US_SL_A|2022-01-01 00:00:00| 3.8|
| US| US_SL_A|2022-01-02 00:00:00| 4.3|
| US| US_SL_A|2022-01-03 00:00:00| 4.3|
| US| US_SL_A|2022-01-04 00:00:00|3.95|
| US| US_SL_A|2022-01-05 00:00:00| 1.0|
| US| US_SL_B|2022-01-01 00:00:00| 4.3|
| US| US_SL_B|2022-01-02 00:00:00| 3.8|
| US| US_SL_B|2022-01-03 00:00:00| 9.0|
| US| US_SL_C|2022-01-01 00:00:00| 1.0|
| ES| ES_SL_A|2022-01-01 00:00:00| 4.2|
| ES| ES_SL_A|2022-01-02 00:00:00| 1.0|
| ES| ES_SL_B|2022-01-01 00:00:00| 2.0|
| FR| FR_SL_A|2022-01-01 00:00:00| 2.0|
+-------+--------+-------------------+----+
我的目标是检测尺寸列中异常值的数量,但之前是按国家和平台分组的。为此,我想使用四分位数范围作为标准;也就是说,我想计算所有那些值小于分位数 0.25 减去四分位数间距的 1.5 倍的尺寸。
我可以通过以下方式获得不同的分位数参数和每组所需的阈值:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.round(1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)), 2)
).alias("1.5xInterquartile"),
F.percentile_approx("size", 0.25).alias("q1"),
F.percentile_approx("size", 0.75).alias("q3"),
)\
.withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`"))\ # Q1 - 1.5*IQR
.show()
+-------+--------+-----------------+---+---+---------+
|country|platform|1.5xInterquartile| q1| q3|threshold|
+-------+--------+-----------------+---+---+---------+
| US| US_SL_A| 0.75|3.8|4.3| 3.05|
| US| US_SL_B| 7.8|3.8|9.0| -4.0|
| US| US_SL_C| 0.0|1.0|1.0| 1.0|
| ES| ES_SL_A| 4.8|1.0|4.2| -3.8|
| FR| FR_SL_A| 0.0|2.0|2.0| 2.0|
| ES| ES_SL_B| 0.0|2.0|2.0| 2.0|
+-------+--------+-----------------+---+---+---------+
但这并不是我想要得到的。我想要的是,不是按四分位数聚合,而是按满足低于异常值阈值条件的每组行数的计数进行聚合。
期望的输出是这样的:
+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
| US| US_SL_A| 1 |
| US| US_SL_B| 0 |
| US| US_SL_C| 0 |
| ES| ES_SL_A| 0 |
| FR| FR_SL_A| 0 |
| ES| ES_SL_B| 0 |
+-------+--------+----------+
这是因为只有 (US, US_SL_A)
组有一个值 (1.) 低于此类组的离群值阈值
这是我实现该目标的尝试:
>> df.groupBy(
["country", "platform"]
).agg(
(
F.count(
F.when(
F.col("size") < F.percentile_approx("size", 0.25) - 1.5*(F.percentile_approx("size", 0.75) - F.percentile_approx("size", 0.25)),
True
)
)
).alias("n_outliers"),
)
但是我收到一个错误,其中指出:
AnalysisException: It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;
Aggregate [country#0, platform#1], [country#0, platform#1, count(CASE WHEN (size#3 < (percentile_approx(size#3, 0.25, 10000, 0, 0) - ((percentile_approx(size#3, 0.75, 10000, 0, 0) - percentile_approx(size#3, 0.25, 10000, 0, 0)) * 1.5))) THEN true END) AS n_outliers#732L]
+- LogicalRDD [country#0, platform#1, timestamp#2, size#3], false
这里的关键是在聚合之前使用windows函数
import pyspark.sql.window as W
w = W.Window.partitionBy(["country", "platform"])
(df
.withColumn("1.5xInterquartile", F.round(1.5*(F.percentile_approx("size", 0.75).over(w) - F.percentile_approx("size", 0.25).over(w)), 2))
.withColumn("q1",F.percentile_approx("size", 0.25).over(w))
.withColumn("q3",F.percentile_approx("size", 0.75).over(w))
.withColumn("threshold", F.col("q1") - F.col("`1.5xInterquartile`")) # Q1 - 1.5*IQR
.groupBy(["country", "platform"])
.agg(F.count(F.when(F.col("size") < F.col("q1") - 1.5*(F.col("q3") - F.col("q1")), 1)).alias("n_outliers"))
.show()
)
+-------+--------+----------+
|country|platform|n_outliers|
+-------+--------+----------+
| ES| ES_SL_A| 0|
| ES| ES_SL_B| 0|
| FR| FR_SL_A| 0|
| US| US_SL_A| 1|
| US| US_SL_B| 0|
| US| US_SL_C| 0|
+-------+--------+----------+
您的 count
和 percentile_approx
都需要聚合,但看起来上面的 agg
并没有处理这些。
您可以尝试对所有聚合使用 window 函数,这将为每条记录添加 n_outliers
计数。然后,稍后您可以使用 distinct
仅获取每组 1 条记录。
w = Window.partitionBy("country", "platform")
df = (df.withColumn('n_outliers',
F.count(F.when(
F.col("size") < (F.percentile_approx("size", 0.25).over(w) - 1.5*(F.percentile_approx("size", 0.75).over(w) - F.percentile_approx("size", 0.25).over(w))),
1
)).over(w))
.select('country', 'platform', 'n_outliers')
.distinct())