Spark 数据集等效于 scala 的 "collect" 采用部分函数

Spark Dataset equivalent for scala's "collect" taking a partial function

常规 scala 集合有一个漂亮的 collect 方法,它让我可以使用部分函数一次完成 filter-map 操作。在 spark Datasets?

上是否有等效的操作

我喜欢它有两个原因:


这里有一个例子来说明我的意思。假设我有一系列选项,我想提取和加倍定义的整数(Some 中的整数):

val input = Seq(Some(3), None, Some(-1), None, Some(4), Some(5)) 

方法一 - collect

input.collect {
  case Some(value) => value * 2
} 
// List(6, -2, 8, 10)

collect 使它在语法上非常整洁并且一次通过。

方法二 - filter-map

input.filter(_.isDefined).map(_.get * 2)

我可以将这种模式带到 spark 中,因为数据集和数据框具有类似的方法。

但我不太喜欢这个,因为 isDefinedget 对我来说似乎是代码的味道。有一个隐含的假设,即 map 仅接收 Somes。编译器无法验证这一点。在一个更大的示例中,开发人员更难发现该假设,并且开发人员可能会交换过滤器和映射,而不会出现语法错误。

方法 3 - fold* 操作

input.foldRight[List[Int]](Nil) {
  case (nextOpt, acc) => nextOpt match {
    case Some(next) => next*2 :: acc
    case None => acc
  }
}

我没有充分使用 spark 来了解 fold 是否有等价物,所以这可能有点切线。

反正模式匹配、折叠样板和列表的重建都乱七八糟,很难读。


所以总的来说,我发现 collect 语法最好,我希望 spark 有这样的东西。

RDDs和Datasets上定义的collect方法用于实现驱动程序中的数据。

尽管没有类似于 Collections API collect 方法的东西,但您的直觉是正确的:因为这两个操作都是惰性评估的,引擎有机会优化操作并将它们链接起来它们是在最大位置执行的。

对于您特别提到的用例,我建议您考虑 flatMap,它适用于 RDDs 和 Datasets:

// Assumes the usual spark-shell environment
// sc: SparkContext, spark: SparkSession
val collection = Seq(Some(1), None, Some(2), None, Some(3))
val rdd = sc.parallelize(collection)
val dataset = spark.createDataset(rdd)

// Both operations will yield `Array(2, 4, 6)`
rdd.flatMap(_.map(_ * 2)).collect
dataset.flatMap(_.map(_ * 2)).collect

// You can also express the operation in terms of a for-comprehension
(for (option <- rdd; n <- option) yield n * 2).collect
(for (option <- dataset; n <- option) yield n * 2).collect

// The same approach is valid for traditional collections as well
collection.flatMap(_.map(_ * 2))
for (option <- collection; n <- option) yield n * 2

编辑

正如在另一个问题中正确指出的那样,RDDs 实际上有 collect 方法,它通过应用部分函数来转换 RDD ,就像在普通集合中发生的那样。然而,正如 Spark documentation 指出的那样,"this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."

为了完整起见:

RDD API 确实有这样的方法,所以总是可以选择将给定的Dataset / DataFrame转换为RDD,执行collect操作并转换回来,例如:

val dataset = Seq(Some(1), None, Some(2)).toDS()
val dsResult = dataset.rdd.collect { case Some(i) => i * 2 }.toDS()

但是,这可能比在数据集上使用地图和过滤器表现更差(原因在@stefanobaghino 的回答中解释)。

至于 DataFrame,这个特定示例(使用 Option)有些误导,因为转换为 DataFrame 实际上会将选项的 "flatenning" 转换为它们的值(或 null对于 None),因此等效表达式为:

val dataframe = Seq(Some(1), None, Some(2)).toDF("opt")
dataframe.withColumn("opt", $"opt".multiply(2)).filter(not(isnull($"opt")))

我认为,您对地图操作 "assume" 对其输入的任何担忧的影响较小。

我只是想通过包含一个 for 对案例 class 的理解示例来扩展 stefanobaghino 的回答,因为许多用例可能会涉及案例 classes。

此外,选项是 monad,这使得在这种情况下接受的答案非常简单,因为 for 巧妙地删除了 None 值,但这种方法不会扩展到 non-monads像案例 classes:

case class A(b: Boolean, i: Int, d: Double)

val collection = Seq(A(true, 3), A(false, 10), A(true, -1))
val rdd = ...
val dataset = ...

// Select out and double all the 'i' values where 'b' is true:
for {
  A(b, i, _) <- dataset
  if b
} yield i * 2

这里的答案是不正确的,至少以现在的Spark是这样。

RDD 实际上有一个 collect 方法,它采用部分函数并对数据应用过滤器和映射。这与无参数的 .collect() 方法完全不同。查看 Spark 源代码 RDD.scala @ line 955:

/**
 * Return an RDD that contains all matching values by applying `f`.
 */
def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
  val cleanF = sc.clean(f)
  filter(cleanF.isDefinedAt).map(cleanF)
}

与第 923 行 RDD.scala 中的无参数 .collect() 方法相反,这不会具体化 RDD 中的数据:

/**
 * Return an array that contains all of the elements in this RDD.
 */
def collect(): Array[T] = withScope {
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}

在文档中,请注意

def collect[U](f: PartialFunction[T, U]): RDD[U]

方法 没有 有一个关于正在加载到驱动程序内存中的数据的警告:

https://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.rdd.RDD@collect[U](f:PartialFunction[T,U])(implicitevidence:scala.reflect.ClassTag[U]):org.apache.spark.rdd.RDD[U]

让这些重载方法做完全不同的事情对 Spark 来说非常混乱。


编辑:我错了!我误解了这个问题,我们在谈论数据集而不是 RDD。不过,接受的答案是

"the Spark documentation points out, however, "this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."

这是不正确的!调用 .collect() 的部分函数版本时,数据不会加载到驱动程序的内存中 - 只有在调用无参数版本时才会加载。调用 .collect(partial_function) 的性能应该与顺序调用 .filter() 和 .map() 的性能大致相同,如上面的源代码所示。

您始终可以创建自己的扩展方法:

implicit class DatasetOps[T](ds: Dataset[T]) {

  def collectt[U](pf: PartialFunction[T, U])(implicit enc: Encoder[U]): Dataset[U] = {
    ds.flatMap(pf.lift(_))
  }
}

这样:

// val ds = Dataset(1, 2, 3)
ds.collectt { case x if x % 2 == 1 => x * 3 }
// Dataset(3, 9)

请注意,不幸的是我无法将其命名为 collect(因此可怕的后缀 t),否则签名会(我认为)与现有的 Dataset#collect 方法将 Dataset 转换为 Array.