pyspark collect_set of column outside of groupby

I am trying to use collect_set to get a list of category_name strings which are NOT part of the groupby. My code is

from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql import functions as F

sc = SparkContext("local")
sqlContext = HiveContext(sc)
df = sqlContext.createDataFrame([
     ("1", "cat1", "Dept1", "product1", 7),
     ("2", "cat2", "Dept1", "product1", 100),
     ("3", "cat2", "Dept1", "product2", 3),
     ("4", "cat1", "Dept2", "product3", 5),
    ], ["id", "category_name", "department_id", "product_id", "value"])

df.show()
df.groupby("department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .show()

#            .agg( F.collect_set("category_name"))\

The output is

+---+-------------+-------------+----------+-----+
| id|category_name|department_id|product_id|value|
+---+-------------+-------------+----------+-----+
|  1|         cat1|        Dept1|  product1|    7|
|  2|         cat2|        Dept1|  product1|  100|
|  3|         cat2|        Dept1|  product2|    3|
|  4|         cat1|        Dept2|  product3|    5|
+---+-------------+-------------+----------+-----+

+-------------+----------+----------+
|department_id|product_id|sum(value)|
+-------------+----------+----------+
|        Dept1|  product2|         3|
|        Dept1|  product1|       107|
|        Dept2|  product3|         5|
+-------------+----------+----------+

I want this output:

+-------------+----------+----------+----------------------------+
|department_id|product_id|sum(value)| collect_list(category_name)|
+-------------+----------+----------+----------------------------+
|        Dept1|  product2|         3|  cat2                      |
|        Dept1|  product1|       107|  cat1, cat2                |
|        Dept2|  product3|         5|  cat1                      |
+-------------+----------+----------+----------------------------+

Attempt 1

df.groupby("department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .agg(F.collect_set("category_name")) \
    .show()

I get this error:

pyspark.sql.utils.AnalysisException: "cannot resolve 'category_name' given input columns: [department_id, product_id, sum(value)];;\n'Aggregate [collect_set('category_name, 0, 0) AS collect_set(category_name)#35]\n+- Aggregate [department_id#2, product_id#3], [department_id#2, product_id#3, sum(value#4L) AS sum(value)#24L]\n +- LogicalRDD [id#0, category_name#1, department_id#2, product_id#3, value#4L]\n"

Attempt 2: I made category_name part of the groupby

df.groupby("category_name", "department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .agg(F.collect_set("category_name")) \
    .show()

It works, but the output is not what I want:

+--------------------------+
|collect_set(category_name)|
+--------------------------+
|              [cat1, cat2]|
+--------------------------+

You can do this in a single aggregation. The correct syntax for your case is:

df.groupby("department_id", "product_id")\
    .agg(F.sum('value'), F.collect_set("category_name"))\
    .show()
#+-------------+----------+----------+--------------------------+
#|department_id|product_id|sum(value)|collect_set(category_name)|
#+-------------+----------+----------+--------------------------+
#|        Dept1|  product2|         3|                    [cat2]|
#|        Dept1|  product1|       107|              [cat1, cat2]|
#|        Dept2|  product3|         5|                    [cat1]|
#+-------------+----------+----------+--------------------------+
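
If you also want the comma-separated string shown in your desired output (rather than an array column), a minimal sketch, untested here, is to wrap collect_set in concat_ws and alias the result; the "categories" column name is just an illustrative choice:

df.groupby("department_id", "product_id")\
    .agg(
        F.sum("value").alias("sum(value)"),
        # concat_ws joins the collected set into one comma-separated string
        F.concat_ws(", ", F.collect_set("category_name")).alias("categories"))\
    .show()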

Your approach doesn't work because the first .agg() applies to a pyspark.sql.group.GroupedData and returns a new DataFrame. The subsequent call to agg is actually pyspark.sql.DataFrame.agg, which is

shorthand for df.groupBy().agg()

So the second call to agg effectively groups the whole DataFrame again, which is not what you want.
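
As a quick illustration (assuming the same df as above), DataFrame.agg behaves like an aggregation over an empty groupBy of the whole frame, so the two calls below are equivalent:

df.agg(F.collect_set("category_name")).show()
df.groupBy().agg(F.collect_set("category_name")).show()
# Both should return a single row containing [cat1, cat2] -
# the same collapsed result the extra .agg() in Attempt 2 produced.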