Spark aggregate on multiple columns within partition without shuffle

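I'm trying to aggregate a DataFrame on multiple columns. I know that everything I need for the aggregation is within the partition; that is, there's no need for a shuffle, because all of the data for each aggregation is local to the partition.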

For example, if I have something like

    val sales = sc.parallelize(List(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5))).repartition(2)
    val tdf = sales
      .map { case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }
      .reduceByKey((x, y) => (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4))
    println(tdf.toDebugString)

I get something like this:

(2) ShuffledRDD[12] at reduceByKey at Test.scala:59 []
 +-(2) MapPartitionsRDD[11] at map at Test.scala:58 []
    |  MapPartitionsRDD[10] at repartition at Test.scala:57 []
    |  CoalescedRDD[9] at repartition at Test.scala:57 []
    |  ShuffledRDD[8] at repartition at Test.scala:57 []
    +-(1) MapPartitionsRDD[7] at repartition at Test.scala:57 []
       |  ParallelCollectionRDD[6] at parallelize at Test.scala:51 []

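I can see the MapPartitionsRDD, which is good. But there is also the ShuffledRDD, which I want to prevent, because I want a per-partition summary, grouped by column values within the partition.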

zero323's answer is really close, but I need the "group by columns" functionality.

Referring to my example above, I'm looking for the result that would be produced by

    select store, prod, sum(amt), avg(units) from sales group by partition_id, store, prod

(I don't really need the partition id; it's only there to illustrate that I want per-partition results.)

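I've looked at several other examples, but every debug string I produce has a shuffle. I really want to get rid of the shuffle. I guess I'm essentially looking for a groupByKeysWithinPartitions function.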
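There is no `groupByKeysWithinPartitions` in the RDD API that I know of, but the behaviour described above can be sketched as a plain Scala function over one partition's iterator and handed to `mapPartitions`, which adds a MapPartitionsRDD rather than a ShuffledRDD. The helper below is hypothetical (its name and signature are mine, not Spark's). It keeps one accumulator per distinct key in a `LinkedHashMap`, so it does not require the partition to be sorted, but it holds all of a partition's distinct keys in memory at once.

```scala
import scala.collection.mutable

object WithinPartition {
  // Hypothetical helper (not part of Spark): aggregate one partition's rows by key.
  // Every distinct key of the partition stays in memory until the iterator is exhausted.
  def aggregateWithinPartition[V, K, A](rows: Iterator[V])(key: V => K)(
      init: V => A)(combine: (A, V) => A): Iterator[(K, A)] = {
    // LinkedHashMap keeps keys in the order they were first seen.
    val acc = mutable.LinkedHashMap.empty[K, A]
    rows.foreach { row =>
      val k = key(row)
      acc.get(k) match {
        case Some(a) => acc.update(k, combine(a, row)) // fold into existing group
        case None    => acc.update(k, init(row))       // first row of a new group
      }
    }
    acc.iterator
  }
}
```

Applied to the question's RDD, something like `sales.mapPartitions(rows => WithinPartition.aggregateWithinPartition(rows)(r => (r._1, r._2))(r => (r._3, r._4, 1))((a, r) => (a._1 + r._3, a._2 + r._4, a._3 + 1)))` would yield, per partition, `((store, prod), (sum(amt), sum(units), count))`, from which avg(units) is `sum(units) / count`.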

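The only way to achieve this is to use mapPartitions with custom code that groups the values and computes the aggregates while iterating over the partition. As you mentioned, the data is already sorted by the grouping keys (store, prod), so we can compute your aggregations efficiently in a pipelined fashion: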

(1) Define the helper classes:

:paste

case class MyRec(store: String, prod: String, amt: Double, units: Int)

case class MyResult(store: String, prod: String, total_amt: Double, min_amt: Double, max_amt: Double, total_units: Int)

object MyResult {
  def apply(rec: MyRec): MyResult = new MyResult(rec.store, rec.prod, rec.amt, rec.amt, rec.amt, rec.units)

  def aggregate(result: MyResult, rec: MyRec) = {
    new MyResult(result.store,
      result.prod,
      result.total_amt + rec.amt,
      math.min(result.min_amt, rec.amt),
      math.max(result.max_amt, rec.amt),
      result.total_units + rec.units
    )
  }
}

(2) Define the pipelined aggregator:

:paste

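def pipelinedAggregator(iter: Iterator[MyRec]): Iterator[Seq[MyResult]] = {

  var prev: MyResult = null
  var res: Seq[MyResult] = Nil

  for (crt <- iter) yield {
    // Start each step with an empty batch; otherwise a finished group would be emitted
    // again for every following row of the next group (e.g. keys A, A, B, B).
    res = Nil

    if (prev == null) {
      // first record of the partition
      prev = MyResult(crt)
    }
    else if (prev.prod != crt.prod || prev.store != crt.store) {
      // key changed: emit the finished group and start a new one
      res = Seq(prev)
      prev = MyResult(crt)
    }
    else {
      // same key: fold the record into the running aggregate
      prev = MyResult.aggregate(prev, crt)
    }

    // last record of the partition: also emit the group still in progress
    if (!iter.hasNext) {
      res = res ++ Seq(prev)
    }

    res
  }

}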
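A variant of the same pipelined idea, sketched here under my own hypothetical names (`Pipelined.pipelined` is not part of Spark or of the answer above), emits one result per run of equal keys directly, so no `Seq` wrapper or `flatMap(identity)` is needed. Like `pipelinedAggregator`, it assumes rows with the same key are adjacent within a partition; a key that reappears later in the partition starts a new group.

```scala
object Pipelined {
  // Hypothetical helper: fold each run of consecutive rows sharing a key into one result.
  // Assumes equal keys are adjacent; only the current group's accumulator is held in memory.
  def pipelined[V, K, A](rows: Iterator[V])(key: V => K)(
      init: V => A)(combine: (A, V) => A): Iterator[A] = {
    val buffered = rows.buffered // lets us peek at the next row without consuming it
    new Iterator[A] {
      def hasNext: Boolean = buffered.hasNext

      def next(): A = {
        val first = buffered.next()
        val k = key(first)
        var acc = init(first)
        // Consume following rows while they belong to the same key.
        while (buffered.hasNext && key(buffered.head) == k) {
          acc = combine(acc, buffered.next())
        }
        acc
      }
    }
  }
}
```

With the step (1) classes it could be called as `sales.mapPartitions(iter => Pipelined.pipelined(iter)(r => (r.store, r.prod))(r => MyResult(r))(MyResult.aggregate))`, which returns a `Dataset[MyResult]` without the extra `flatMap`.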

(3) Run the aggregation:

:paste

val sales = sc.parallelize(
  List(MyRec("West", "Apple", 2.0, 10),
    MyRec("West", "Apple", 3.0, 15),
    MyRec("West", "Orange", 5.0, 15),
    MyRec("South", "Orange", 3.0, 9),
    MyRec("South", "Orange", 6.0, 18),
    MyRec("East", "Milk", 5.0, 5),
    MyRec("West", "Apple", 7.0, 11)), 2).toDS

sales.mapPartitions(iter => Iterator(iter.toList)).show(false)

val result = sales
  .mapPartitions(recIter => pipelinedAggregator(recIter))
  .flatMap(identity)

result.show
result.explain

Output:

    +-------------------------------------------------------------------------------------+
    |value                                                                                |
    +-------------------------------------------------------------------------------------+
    |[[West,Apple,2.0,10], [West,Apple,3.0,15], [West,Orange,5.0,15]]                     |
    |[[South,Orange,3.0,9], [South,Orange,6.0,18], [East,Milk,5.0,5], [West,Apple,7.0,11]]|
    +-------------------------------------------------------------------------------------+

    +-----+------+---------+-------+-------+-----------+
    |store|  prod|total_amt|min_amt|max_amt|total_units|
    +-----+------+---------+-------+-------+-----------+
    | West| Apple|      5.0|    2.0|    3.0|         25|
    | West|Orange|      5.0|    5.0|    5.0|         15|
    |South|Orange|      9.0|    3.0|    6.0|         27|
    | East|  Milk|      5.0|    5.0|    5.0|          5|
    | West| Apple|      7.0|    7.0|    7.0|         11|
    +-----+------+---------+-------+-------+-----------+

    == Physical Plan ==
    *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).store, true) AS store#31, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).prod, true) AS prod#32, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_amt AS total_amt#33, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).min_amt AS min_amt#34, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).max_amt AS max_amt#35, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_units AS total_units#36]
    +- MapPartitions <function1>, obj#30: $line14.$read$$iw$$iw$MyResult
       +- MapPartitions <function1>, obj#20: scala.collection.Seq
          +- Scan ExternalRDDScan[obj#4]
    sales: org.apache.spark.sql.Dataset[MyRec] = [store: string, prod: string ... 2 more fields]
    result: org.apache.spark.sql.Dataset[MyResult] = [store: string, prod: string ... 4 more fields]    

If this is the output you are looking for:

+-----+------+--------+----------+
|store|prod  |max(amt)|avg(units)|
+-----+------+--------+----------+
|South|Orange|6.0     |13.5      |
|West |Orange|5.0     |15.0      |
|East |Milk  |5.0     |5.0       |
|West |Apple |3.0     |12.5      |
+-----+------+--------+----------+

Spark DataFrames have all the functionality you are asking for, with a generic, concise shorthand syntax:

import org.apache.spark.sql._
import org.apache.spark.sql.functions._


object TestJob2 {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5)
    ).toDF("store", "prod", "amt", "units")

    rawDf.show(false)
    rawDf.printSchema

    val aggDf = rawDf
      .groupBy("store", "prod")
      .agg(
        max(col("amt")),
        avg(col("units"))
        // in case you need to retain more info
        // , collect_list(struct("*")).as("horizontal")
      )

    aggDf.printSchema

    aggDf.show(false)
  }
}

Uncomment the collect_list line to also keep all the original rows of each group:

+-----+------+--------+----------+---------------------------------------------------+
|store|prod  |max(amt)|avg(units)|horizontal                                         |
+-----+------+--------+----------+---------------------------------------------------+
|South|Orange|6.0     |13.5      |[[South, Orange, 3.0, 9], [South, Orange, 6.0, 18]]|
|West |Orange|5.0     |15.0      |[[West, Orange, 5.0, 15]]                          |
|East |Milk  |5.0     |5.0       |[[East, Milk, 5.0, 5]]                             |
|West |Apple |3.0     |12.5      |[[West, Apple, 2.0, 10], [West, Apple, 3.0, 15]]   |
+-----+------+--------+----------+---------------------------------------------------+

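The max and avg aggregations you specified are computed over multiple rows.

If you want to keep all the original rows, use a Window function partitioned by the grouping columns.

If you want to reduce the rows in each partition, you have to specify the reduction logic or a filter.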

import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._


object TestJob7 {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5)
    ).toDF("store", "prod", "amt", "units")


    rawDf.show(false)
    rawDf.printSchema

    val storeProdWindow = Window
      .partitionBy("store", "prod")

    val aggDf = rawDf
      .withColumn("max(amt)", max("amt").over(storeProdWindow))
      .withColumn("avg(units)", avg("units").over(storeProdWindow))

    aggDf.printSchema

    aggDf.show(false)
  }
}

Here is the result. Note that the rows are already grouped (the window shuffles them into partitions by store and prod):

+-----+------+---+-----+--------+----------+
|store|prod  |amt|units|max(amt)|avg(units)|
+-----+------+---+-----+--------+----------+
|South|Orange|3.0|9    |6.0     |13.5      |
|South|Orange|6.0|18   |6.0     |13.5      |
|West |Orange|5.0|15   |5.0     |15.0      |
|East |Milk  |5.0|5    |5.0     |5.0       |
|West |Apple |2.0|10   |3.0     |12.5      |
|West |Apple |3.0|15   |3.0     |12.5      |
+-----+------+---+-----+--------+----------+
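For comparison, here is a plain-Scala sketch (no Spark; `WindowLike.annotate` is an illustrative name, not an API) of what the two window columns compute: every input row is kept and extended with the max(amt) and avg(units) of its (store, prod) group, which should reproduce the values in the table above. It only illustrates the semantics; it says nothing about how Spark partitions or shuffles the data.

```scala
object WindowLike {
  type Sale = (String, String, Double, Int) // (store, prod, amt, units)

  // Keep every row and append the max(amt) and avg(units) of its (store, prod) group,
  // mirroring max("amt").over(w) and avg("units").over(w) with w = partitionBy("store", "prod").
  def annotate(rows: Seq[Sale]): Seq[(Sale, Double, Double)] = {
    val stats: Map[(String, String), (Double, Double)] =
      rows.groupBy(r => (r._1, r._2)).map { case (k, group) =>
        k -> (group.map(_._3).max, group.map(_._4).sum.toDouble / group.size)
      }
    rows.map { r =>
      val (maxAmt, avgUnits) = stats((r._1, r._2))
      (r, maxAmt, avgUnits)
    }
  }
}
```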

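Aggregate functions reduce the values of the specified columns across the rows of a group. You can perform several different aggregations in a single pass, each producing a new column from the input rows' values, entirely with DataFrame functionality. If you want to retain other values from the rows, you need to implement reduction logic that specifies which row each value comes from; for example, keep all the values of the first row that has the maximum amt. To do this, you can use a UDAF (user-defined aggregate function) to reduce the rows within a group. In the example, I also compute the max amt and the average units with standard aggregate functions in the same pass.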
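The reduction described above can be illustrated without Spark (a sketch; `KeepMax` and its methods are hypothetical names). `keepMaxAmt` keeps the whole row with the larger amt, preferring the left-hand row on ties. Because picking a maximum is associative, reducing each partition first and then merging the partial winners yields the same maximum amt as reducing all rows at once. A UDAF's `merge` needs the same property, so it should compare the two buffers rather than copy one over the other. Which row wins a tie can still depend on the merge order.

```scala
object KeepMax {
  type Sale = (String, String, Double, Int) // (store, prod, amt, units)

  // Keep the whole row with the larger amt; on a tie, keep the left-hand row.
  def keepMaxAmt(a: Sale, b: Sale): Sale = if (b._3 > a._3) b else a

  // Reduce one partition to a single winning row per (store, prod) key.
  def reducePartition(rows: Seq[Sale]): Map[(String, String), Sale] =
    rows.groupBy(r => (r._1, r._2)).map { case (k, group) =>
      k -> group.reduce((a, b) => keepMaxAmt(a, b))
    }

  // Merge the partial winners of two partitions using the same reduce function.
  def mergePartials(
      p1: Map[(String, String), Sale],
      p2: Map[(String, String), Sale]): Map[(String, String), Sale] =
    p2.foldLeft(p1) { case (acc, (k, row)) =>
      acc.updated(k, acc.get(k).map(keepMaxAmt(_, row)).getOrElse(row))
    }
}
```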

import org.apache.spark.sql._
import org.apache.spark.sql.functions._


object ReduceAggJob {

  def main (args: Array[String]): Unit = {

    val appName = this.getClass.getName.replace("$", "")
    println(s"appName: $appName")

    val sparkSession = SparkSession
      .builder()
      .appName(appName)
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      ("West",  "Apple",  2.0, 10),
      ("West",  "Apple",  3.0, 15),
      ("West",  "Orange", 5.0, 15),
      ("West",  "Orange", 17.0, 15),
      ("South", "Orange", 3.0, 9),
      ("South", "Orange", 6.0, 18),
      ("East",  "Milk",   5.0, 5)
    ).toDF("store", "prod", "amt", "units")

    rawDf.printSchema
    rawDf.show(false)
    // Create an instance of the KeepRowWithMaxAmt UDAF.
    val maxAmtUdaf = new KeepRowWithMaxAmt

    // Keep the row with max amt
    val aggDf = rawDf
      .groupBy("store", "prod")
      .agg(
        max("amt"),
        avg("units"),
        maxAmtUdaf(
          col("store"),
          col("prod"),
          col("amt"),
          col("units")).as("KeepRowWithMaxAmt")
      )

    aggDf.printSchema
    aggDf.show(false)
  }
}

UDAF

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._


class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
  // This is the input fields for your aggregate function.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(
      StructField("store", StringType) ::
      StructField("prod", StringType) ::
      StructField("amt", DoubleType) ::
      StructField("units", IntegerType) :: Nil
    )

  // This is the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType = StructType(
    StructField("store", StringType) ::
    StructField("prod", StringType) ::
    StructField("amt", DoubleType) ::
    StructField("units", IntegerType) :: Nil
  )


  // This is the output type of your aggregation function.
  override def dataType: DataType =
    StructType((Array(
      StructField("store", StringType),
      StructField("prod", StringType),
      StructField("amt", DoubleType),
      StructField("units", IntegerType)
    )))

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  // amt starts as null, so the first input row always wins, even if every amount is negative.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = null
    buffer(1) = null
    buffer(2) = null
    buffer(3) = null
  }

  // Copy all four fields of `row` into `buffer` when `row` has a strictly larger amt.
  // Input rows and buffer rows share the same layout (store, prod, amt, units).
  private def keepIfLarger(buffer: MutableAggregationBuffer, row: Row): Unit = {
    val candidateWins =
      !row.isNullAt(2) && (buffer.isNullAt(2) || buffer.getDouble(2) < row.getDouble(2))
    if (candidateWins) {
      buffer(0) = row.get(0)
      buffer(1) = row.get(1)
      buffer(2) = row.get(2)
      buffer(3) = row.get(3)
    }
  }

  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    keepIfLarger(buffer, input)

  // This is how to merge two objects with the bufferSchema type.
  // The partial buffers have to be compared too: copying buffer2 unconditionally would
  // keep whichever partial result happened to be merged last, not the one with max amt.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    keepIfLarger(buffer1, buffer2)

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    buffer
  }
}