在新数据框中对来自 pyspark 数据框的子字段进行分组并按链接列排序

grouping sub-fields from pyspark dataframe in new dataframe and sorting by linked column

我有一个 pyspark 数据框,其中有一列包含不同类型的电影。它看起来像这样:

|Movie Name| Genres | Review |
|X         | Y, Z   | 10     |

我需要根据用户评论找到前 N 个类型,这是每部电影的一列。我已经将流派列分解为它自己的数据框,其中包含这样的评论列:

splitDf = df.withColumn("genre", explode(split(col("genre"), "[,]")))

这样可以将每个类型与每个电影列表分开。但是现在我需要通过评论对它们进行排名,并且每个不同的类型(来自原始电影 df 中的每一行)都剩下重复的行。我试过了

specifiedDf = splitDf.select("genre","user_review").groupBy("genre").avg("user_review")

我试过调整 table,但似乎没有任何东西可以将这些类型组合在一起,所以我可以对评论进行平均。

按照给出的建议,我可以在 pandas 中使用

splitDf= df.to_pandas_on_spark()
splitDf['genre'] = splitDf['genre'].str.split(',\s*')
resultDf = result.explode('genre')[['genre','user_review']].groupby('genre').agg("avg") 
resultDf = resultDf.sort_values(by="user_review", ascending=False)

但是我仍然无法将其转换为 pyspark,这是我主要修改的代码

splitArrayDf = df.select(split('genre', ',').alias("genre"),"user_review")    
splitArrayDf = splitArrayDf.select(explode("genre").alias("genre"),"user_review") /
.groupBy("genre").agg({"user_review":"avg"})

这会创建重复的流派字段,而 pandas 不会。

我建议用逗号分隔 Genres,但将输出分配给同一列。然后你可以 explode 该列并做一个 groupby 来计算每个流派的 Review 的总和:

import pandas as pd

data = '''Movie Name| Genres | Review
X         | Y, Z   | 10
Y         | W, Z   | 7'''

df = pd.read_csv(io.StringIO(data),sep='\s*\|\s*')
df['Genres'] = df['Genres'].str.split(',\s*')
result = df.explode('Genres')[['Genres', 'Review']].groupby('Genres').agg(sum)

输出:

Genres Review
W 7
Y 10
Z 17

根据您在问题中给出的示例,您可能会得到重复项 genre,因为在逗号分隔符 ,.[=14= 之前的 and/or 之后存在空格]

要处理这个问题,您可以在拆分之前用空字符串替换它们,或者简单地使用正则表达式拆分 \s*,\s*:

import pyspark.sql.functions as F

data = [("X", "Y, Z", 10), ("Y", "Z, W", 7)]
df = spark.createDataFrame(data, ["movie_name", "genre", "user_review"])

df1 = df.withColumn(
    "genre",
    F.explode(F.split("genre", r"\s*,\s*"))
).groupBy("genre").agg(
    F.avg("user_review").alias("user_review")
)

df1.show()
#+-----+-----------+
#|genre|user_review|
#+-----+-----------+
#|    Y|       10.0|
#|    Z|        8.5|
#|    W|        7.0|
#+-----+-----------+