过滤然后计算许多不同的阈值

filter then count for many different threshold

我想计算在非常大的数据帧上满足条件的行数,可以通过

实现
df.filter(col("value") >= thresh).count()

我想知道 [1, 10] 范围内每个阈值的结果。枚举每个阈值然后执行此操作将扫描数据帧 10 次。很慢。

如果只扫描一次df就可以实现?

使用带有 when 表达式的条件聚合应该可以完成这项工作。

这是一个例子:

from pyspark.sql import functions as F

df = spark.createDataFrame([(1,), (2,), (3,), (4,), (4,), (6,), (7,)], ["value"])

count_expr = [
    F.count(F.when(F.col("value") >= th, 1)).alias(f"gte_{th}")
    for th in range(1, 11)
]

df.select(*count_expr).show()
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|gte_1|gte_2|gte_3|gte_4|gte_5|gte_6|gte_7|gte_8|gte_9|gte_10|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+
#|    7|    6|    5|    4|    2|    2|    1|    0|    0|     0|
#+-----+-----+-----+-----+-----+-----+-----+-----+-----+------+

为每个阈值创建一个指标列,然后求和:

import random
import pyspark.sql.functions as F
from pyspark.sql import Row

df = spark.createDataFrame([Row(value=random.randint(0,10)) for _ in range(1_000_000)])

df.select([
    (F.col("value") >= thresh)
    .cast("int")
    .alias(f"ind_{thresh}") 
    for thresh in range(1,11)
]).groupBy().sum().show()

# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |sum(ind_1)|sum(ind_2)|sum(ind_3)|sum(ind_4)|sum(ind_5)|sum(ind_6)|sum(ind_7)|sum(ind_8)|sum(ind_9)|sum(ind_10)|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+
# |    908971|    818171|    727240|    636334|    545463|    454279|    363143|    272460|    181729|      90965|
# +----------+----------+----------+----------+----------+----------+----------+----------+----------+-----------+

使用 user-defined 函数 udf 来自 pyspark.sql.functions:

import pandas as pd
import numpy as np

df = pd.DataFrame(np.random.randint(0,100, size=(20)), columns=['val'])
thres =  [90, 80, 30]     # these are the thresholds
thres.sort(reverse=True)  # list needs to be sorted

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
spark = SparkSession.builder \
    .master("local[2]") \
    .appName("myApp") \
    .getOrCreate()
sparkDF = spark.createDataFrame(df)

myUdf = udf(lambda x: 0 if x>thres[0] else 1 if x>thres[1] else 2 if  x>thres[2] else 3)
sparkDF = sparkDF.withColumn("rank", myUdf(sparkDF.val))
sparkDF.show()
# +---+----+                                                                      
# |val|rank|
# +---+----+
# | 28|   3|
# | 54|   2|
# | 19|   3|
# |  4|   3|
# | 74|   2|
# | 62|   2|
# | 95|   0|
# | 19|   3|
# | 55|   2|
# | 62|   2|
# | 33|   2|
# | 93|   0|
# | 81|   1|
# | 41|   2|
# | 80|   2|
# | 53|   2|
# | 14|   3|
# | 16|   3|
# | 30|   3|
# | 77|   2|
# +---+----+
sparkDF.groupby(['rank']).count().show()
# Out: 
# +----+-----+
# |rank|count|
# +----+-----+
# |   3|    7|
# |   0|    2|
# |   1|    1|
# |   2|   10|
# +----+-----+

如果一个值严格大于 thres[i] 但小于或等于 thres[i-1],则该值获得排名 i。这应该最大限度地减少比较次数。

对于 thres = [90, 80, 30],我们有等级 0-> [max, 90[、1-> [90, 80[、2->[80, 30[、3->[30, min]