如何在 PySpark SQL when() 子句中使用聚合值?

How do you use aggregated values within PySpark SQL when() clause?

我正在尝试学习 PySpark,并尝试学习如何使用 SQL when() 子句来更好地分类我的数据。 (参见此处:https://sparkbyexamples.com/spark/spark-case-when-otherwise-example/)我似乎无法解决的是如何将实际标量值插入 when() 条件以明确进行比较。聚合函数似乎 return 比实际的 float() 类型更多的表格值。
我不断收到此错误消息 unsupported operand type(s) for -: 'method' and 'method'
当我尝试 运行 函数时聚合原始数据框中的另一列我注意到结果似乎不像 table (agg(select(f.stddev ("Col")) 给出的结果如下: "DataFrame[stddev_samp(TAXI_OUT): double]") 这是一个示例如果你想复制我想要完成的事情,我想知道你如何在 when() 子句中获得像标准差和平均值这样的聚合值,以便你可以使用它来对你的新列进行分类:

samp = spark.createDataFrame(
    [("A","A1",4,1.25),("B","B3",3,2.14),("C","C2",7,4.24),("A","A3",4,1.25),("B","B1",3,2.14),("C","C1",7,4.24)],
    ["Category","Sub-cat","quantity","cost"])
  
    psMean = samp.agg({'quantity':'mean'})
    psStDev = samp.agg({'quantity':'stddev'})

    psCatVect = samp.withColumn('quant_category',.when(samp['quantity']<=(psMean-psStDev),'small').otherwise('not small')) ```  

你例子中的psMean和psStdev是dataframes,你需要使用collect()方法来提取标量值

psMean = samp.agg({'quantity':'mean'}).collect()[0][0]
psStDev = samp.agg({'quantity':'stddev'}).collect()[0][0]

您还可以创建一个包含所有统计数据的变量作为 pandas DataFrame 并稍后在 pyspark 代码中引用它:

from pyspark.sql import functions as F

stats = (
    samp.select(
        F.mean("quantity").alias("mean"), 
        F.stddev("quantity").alias("std")
    ).toPandas()
)


(
    samp.withColumn('quant_category', 
                F.when(
                    samp['quantity'] <= stats["mean"].item() - stats["std"].item(), 
                    'small')
                .otherwise('not small')
               )
    .toPandas()
)