Aggregating columns conditionally with pyspark?

I have the dataset below. I want to group by all the variables and split the data according to the following conditions.

However, when I try the aggregation code further below, I get an error.

CUST_ID NAME    GENDER  AGE
id_01   MONEY   F       43
id_02   BAKER   F       32
id_03   VOICE   M       31
id_04   TIME    M       56
id_05   TIME    F       24
id_06   TALENT  F       28
id_07   ISLAND  F       21
id_08   ISLAND  F       27
id_09   TUME    F       24
id_10   TIME    F       75
id_11   SKY     M       35
id_12   VOICE   M       70
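
For reference, the sample can be recreated as a DataFrame like this (assuming an active SparkSession named spark):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Recreate the sample data shown above
    data = [
        ("id_01", "MONEY", "F", 43), ("id_02", "BAKER", "F", 32),
        ("id_03", "VOICE", "M", 31), ("id_04", "TIME", "M", 56),
        ("id_05", "TIME", "F", 24), ("id_06", "TALENT", "F", 28),
        ("id_07", "ISLAND", "F", 21), ("id_08", "ISLAND", "F", 27),
        ("id_09", "TUME", "F", 24), ("id_10", "TIME", "F", 75),
        ("id_11", "SKY", "M", 35), ("id_12", "VOICE", "M", 70),
    ]
    df = spark.createDataFrame(data, ["CUST_ID", "NAME", "GENDER", "AGE"])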



    from pyspark.sql.functions import *

    df.groupBy("CUST_ID", "NAME", "GENDER", "AGE").agg(
       CUST_ID.count AS TOTAL
       SUM(WHEN ((AGE >= 18 AND AGE <= 34) AND GENDER = 'M') THEN COUNT(CUST_ID) ELSE 0 END AS "M18-34")
       SUM(WHEN ((AGE >= 18 AND AGE <= 34) AND GENDER = 'F') THEN COUNT(CUST_ID) ELSE 0 END AS "F18-34")
       SUM(WHEN ((AGE >= 18 AND AGE <= 34 THEN COUNT(CUST_ID) ELSE 0 END AS "18-34")
       SUM(WHEN ((AGE >= 25 AND AGE <= 54 THEN COUNT(CUST_ID) ELSE 0 END AS "25-54")
       SUM(WHEN ((AGE >= 25 AND AGE <= 54) AND GENDER = 'F') THEN COUNT(CUST_ID) ELSE 0 END AS "F25-54")
       SUM(WHEN ((AGE >= 25 AND AGE <= 54) AND GENDER = 'M') THEN COUNT(CUST_ID) ELSE 0 END AS "M25-54")   
    )

I would appreciate your help/suggestions.

Thanks in advance.

Your code is neither valid pyspark nor valid Spark SQL; there are a lot of syntax problems. I've tried to fix them below, though I'm not sure this is exactly what you want. If you have this many SQL-like statements, it's better to use Spark SQL directly rather than the pyspark API:

df.createOrReplaceTempView('df')
result = spark.sql("""
SELECT NAME,
       COUNT(CUST_ID) AS TOTAL,
       SUM(CASE WHEN ((AGE >= 18 AND AGE <= 34) AND GENDER = 'M') THEN 1 ELSE 0 END) AS `M18-34`,
       SUM(CASE WHEN ((AGE >= 18 AND AGE <= 34) AND GENDER = 'F') THEN 1 ELSE 0 END) AS `F18-34`,
       SUM(CASE WHEN (AGE >= 18 AND AGE <= 34) THEN 1 ELSE 0 END) AS `18-34`,
       SUM(CASE WHEN (AGE >= 25 AND AGE <= 54) THEN 1 ELSE 0 END) AS `25-54`,
       SUM(CASE WHEN ((AGE >= 25 AND AGE <= 54) AND GENDER = 'F') THEN 1 ELSE 0 END) AS `F25-54`,
       SUM(CASE WHEN ((AGE >= 25 AND AGE <= 54) AND GENDER = 'M') THEN 1 ELSE 0 END) AS `M25-54` 
FROM df
GROUP BY NAME
""")

result.show()
+------+-----+------+------+-----+-----+------+------+
|  NAME|TOTAL|M18-34|F18-34|18-34|25-54|F25-54|M25-54|
+------+-----+------+------+-----+-----+------+------+
|ISLAND|    2|     0|     2|    2|    1|     1|     0|
| MONEY|    1|     0|     0|    0|    1|     1|     0|
|  TIME|    3|     0|     1|    1|    0|     0|     0|
| VOICE|    2|     1|     0|    1|    1|     0|     1|
|  TUME|    1|     0|     1|    1|    0|     0|     0|
| BAKER|    1|     0|     1|    1|    1|     1|     0|
|TALENT|    1|     0|     1|    1|    1|     1|     0|
|   SKY|    1|     0|     0|    0|    1|     0|     1|
+------+-----+------+------+-----+-----+------+------+
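
As an aside, on Spark 3.0+ the SUM(CASE WHEN ... THEN 1 ELSE 0 END) pattern can also be written with the built-in COUNT_IF aggregate. A minimal sketch against the same temp view, showing just one of the columns:

spark.sql("""
SELECT NAME,
       COUNT(CUST_ID) AS TOTAL,
       COUNT_IF(AGE BETWEEN 18 AND 34 AND GENDER = 'M') AS `M18-34`
FROM df
GROUP BY NAME
""").show()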

If you need a pyspark solution, here is an example for the first column. You can work out the rest yourself:

import pyspark.sql.functions as F

result = df.groupBy('Name').agg(
    F.count('CUST_ID').alias('TOTAL'),
    # when() returns null for non-matching rows and count() ignores nulls,
    # so this counts only the rows where the condition holds
    F.count(F.when(F.expr("(AGE >= 18 AND AGE <= 34) AND GENDER = 'M'"), 1)).alias("M18-34")
)

result.show()
+------+-----+------+
|  Name|TOTAL|M18-34|
+------+-----+------+
|ISLAND|    2|     0|
| MONEY|    1|     0|
|  TIME|    3|     0|
| VOICE|    2|     1|
|  TUME|    1|     0|
| BAKER|    1|     0|
|TALENT|    1|     0|
|   SKY|    1|     0|
+------+-----+------+
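
For completeness, here is the same pattern extended to all of the requested columns. This is a sketch in which count_where is just a hypothetical convenience helper; it relies on count() ignoring the nulls that when() produces for non-matching rows:

import pyspark.sql.functions as F

# Hypothetical helper: count rows matching a SQL boolean expression.
# when() yields null where the condition is false; count() skips nulls.
def count_where(cond, label):
    return F.count(F.when(F.expr(cond), 1)).alias(label)

result = df.groupBy('Name').agg(
    F.count('CUST_ID').alias('TOTAL'),
    count_where("AGE BETWEEN 18 AND 34 AND GENDER = 'M'", 'M18-34'),
    count_where("AGE BETWEEN 18 AND 34 AND GENDER = 'F'", 'F18-34'),
    count_where("AGE BETWEEN 18 AND 34", '18-34'),
    count_where("AGE BETWEEN 25 AND 54", '25-54'),
    count_where("AGE BETWEEN 25 AND 54 AND GENDER = 'F'", 'F25-54'),
    count_where("AGE BETWEEN 25 AND 54 AND GENDER = 'M'", 'M25-54'),
)
result.show()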