如何在 Spark 数据集上应用可定制的聚合器?

How to apply customizable Aggregator on Spark Dataset?

我有以下 spark 数据集的架构和学生记录。

id | name | subject | score
1  | Tom  | Math    | 99
1  | Tom  | Math    | 88
1  | Tom  | Physics | 77
2  | Amy  | Math    | 66

我的目标是将这个数据集转移到另一个显示所有学生每门学科最高分记录的列表

id | name | subject_score_list
1  | Tom  | [(Math, 99), (Physics, 77)]
2  | Amy  | [(Math, 66)]

我决定在将此数据集转换为 ((id, name), (subject score)) 键值对后使用 Aggregator 进行转换。

对于缓冲区,我尝试使用可变 Map[String, Integer],这样如果主题存在并且新分数更高,我可以更新分数。这是聚合器的样子

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

type StudentSubjectPair = ((String, String), (String, Integer))
type SubjectMap = collection.mutable.Map[String, Integer]
type SubjectList = List[(String, Integer)]

val StudentSubjectAggregator = new Aggregator[StudentSubjectPair, SubjectMap, SubjectList] {
  def zero: SubjectMap = collection.mutable.Map[String, Integer]()

  def reduce(buf: SubjectMap, input: StudentSubjectPair): SubjectMap = {
    if (buf.contains(input._2._1))
      buf.map{ case (input._2._1, score) => input._2._1 -> math.max(score, input._2._2) }
    else
      buf(input._2._1) = input._2._2
    buf
  }

  def merge(b1: SubjectMap, b2: SubjectMap): SubjectMap = {
    for ((subject, score) <- b2) {
      if (b1.contains(subject))
        b1(subject) = math.max(score, b2(subject))
      else
        b1(subject) = score
    }
    b1
  }

  def finish(buf: SubjectMap): SubjectList = buf.toList

  override def bufferEncoder: Encoder[SubjectMap] = ExpressionEncoder[SubjectMap]
  override def outputEncoder: Encoder[SubjectList] = ExpressionEncoder[SubjectList]
}.toColumn.name("subject_score_list")

我使用 Aggregator,因为我发现它可以自定义,如果我想找到一个科目的平均分数,我可以简单地更改 reducemerge 函数。 post.

我期待两个答案
  1. 使用 Aggregator 完成这项工作是不是一个好方法?还有其他简单的方法可以得到相同的输出吗?
  2. collection.mutable.Map[String, Integer]List[(String, Integer)] 的正确编码器是什么,因为我总是收到以下错误
org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 37.0 failed 1 times, most recent failure: Lost task 0.0 in stage 37.0 (TID 231, localhost, executor driver):
java.lang.ClassCastException: scala.collection.immutable.HashMap$HashTrieMap cannot be cast to scala.collection.mutable.Map
    at $anon.merge(<console>:54)

感谢任何意见和帮助,谢谢!

我认为您可以使用 DataFrame API 实现您想要的结果。

val df= Seq((1 ,"Tom" ,"Math",99),
    (1 ,"Tom" ,"Math" ,88),
    (1 ,"Tom" ,"Physics" ,77),
    (2 ,"Amy" ,"Math"  ,66)).toDF("id", "name", "subject","score")

GroupBy on id, name, and subject for max score, followed by a groupBy on 最高分 id,name with a collect_list on map of subject,score

df.groupBy("id","name", "subject").agg(max("score").as("score")).groupBy("id","name").
    agg(collect_list(map($"subject",$"score")).as("subject_score_list"))


+---+----+--------------------+
| id|name|  subject_score_list|
+---+----+--------------------+
|  1| Tom|[[Physics -> 77],...|
|  2| Amy|      [[Math -> 66]]|
+---+----+--------------------+