在 Spark 数据集中创建总数为 运行 的列

create column with a running total in a Spark Dataset

假设我们有一个包含两列的 Spark 数据集,例如索引和值,按第一列(索引)排序。

((1, 100), (2, 110), (3, 90), ...)

我们想要一个包含第三列的数据集,其中第二列(值)中的值总计 运行。

((1, 100, 100), (2, 110, 210), (3, 90, 300), ...)

关于如何一次通过数据有效地执行此操作的任何建议?或者是否有任何可用于此的固定 CDF 类型函数?

如果需要,Dataset 可以转换为 Dataframe 或 RDD 来完成任务,但它必须保持分布式数据结构。也就是说,它不能简单地收集并转换为数组或序列,并且不使用可变变量(仅 val,不使用 var)。

but it will have to remain a distributed data structure.

不幸的是,您所说的您想要做的事情在 Spark 中是不可能的。如果您愿意将数据集重新分区到单个分区(实际上是将其合并到单个主机上),您可以轻松地编写一个函数来执行您想要的操作,将增加的值保留为一个字段。

由于 Spark 函数在执行时不会通过网络共享状态,因此无法创建共享状态,您需要保持数据集完全分布。

如果您愿意放宽您的要求并允许在一台主机上一次性整合和读取数据,那么您可以通过重新分区到单个分区并应用一个函数来做您想做的事。这不会将数据拉到驱动程序上(将其保存在 HDFS/the 集群中),但仍会在单个执行程序上连续计算输出。例如:

package com.github.nevernaptitsa

import java.io.Serializable
import java.util

import org.apache.spark.sql.{Encoders, SparkSession}

object SparkTest {

  class RunningSum extends Function[Int, Tuple2[Int, Int]] with Serializable {
    private var runningSum = 0
    override def apply(v1: Int): Tuple2[Int, Int] = {
      runningSum+=v1
      return (v1, runningSum)
    }
  }

  def main(args: Array[String]): Unit ={
    val session = SparkSession.builder()
      .appName("runningSumTest")
      .master("local[*]")
      .getOrCreate()
    import session.implicits._
    session.createDataset(Seq(1,2,3,4,5))
      .repartition(1)
      .map(new RunningSum)
      .show(5)
    session.createDataset(Seq(1,2,3,4,5))
      .map(new RunningSum)
      .show(5)
  }

}

这里的两个语句显示了不同的输出,第一个提供了正确的输出(串行,因为调用了 repartition(1)),第二个提供了错误的输出,因为结果是并行计算的。

第一个语句的结果:

+---+---+
| _1| _2|
+---+---+
|  1|  1|
|  2|  3|
|  3|  6|
|  4| 10|
|  5| 15|
+---+---+

第二个语句的结果:

+---+---+
| _1| _2|
+---+---+
|  1|  1|
|  2|  2|
|  3|  3|
|  4|  4|
|  5|  9|
+---+---+

一位同事提出了以下依赖于 RDD.mapPartitionsWithIndex() 方法的建议。 (据我所知,其他数据结构不提供这种对其分区索引的引用。)

val data = sc.parallelize((1 to 5))  // sc is the SparkContext
val partialSums = data.mapPartitionsWithIndex{ (i, values) => 
    Iterator((i, values.sum))
}.collect().toMap  // will in general have size other than data.count
val cumSums = data.mapPartitionsWithIndex{ (i, values) => 
    val prevSums = (0 until i).map(partialSums).sum
    values.scanLeft(prevSums)(_+_).drop(1)
}