Spark 数据集等效于 scala 的 "collect" 采用部分函数
Spark Dataset equivalent for scala's "collect" taking a partial function
常规 scala 集合有一个漂亮的 collect
方法,它让我可以使用部分函数一次完成 filter-map
操作。在 spark Dataset
s?
上是否有等效的操作
我喜欢它有两个原因:
- 语法简单
- 它将
filter-map
样式操作减少到单次通过(尽管在 spark 中我猜有一些优化可以为您发现这些东西)
这里有一个例子来说明我的意思。假设我有一系列选项,我想提取和加倍定义的整数(Some
中的整数):
val input = Seq(Some(3), None, Some(-1), None, Some(4), Some(5))
方法一 - collect
input.collect {
case Some(value) => value * 2
}
// List(6, -2, 8, 10)
collect
使它在语法上非常整洁并且一次通过。
方法二 - filter-map
input.filter(_.isDefined).map(_.get * 2)
我可以将这种模式带到 spark 中,因为数据集和数据框具有类似的方法。
但我不太喜欢这个,因为 isDefined
和 get
对我来说似乎是代码的味道。有一个隐含的假设,即 map 仅接收 Some
s。编译器无法验证这一点。在一个更大的示例中,开发人员更难发现该假设,并且开发人员可能会交换过滤器和映射,而不会出现语法错误。
方法 3 - fold*
操作
input.foldRight[List[Int]](Nil) {
case (nextOpt, acc) => nextOpt match {
case Some(next) => next*2 :: acc
case None => acc
}
}
我没有充分使用 spark 来了解 fold 是否有等价物,所以这可能有点切线。
反正模式匹配、折叠样板和列表的重建都乱七八糟,很难读。
所以总的来说,我发现 collect
语法最好,我希望 spark 有这样的东西。
在RDD
s和Dataset
s上定义的collect
方法用于实现驱动程序中的数据。
尽管没有类似于 Collections API collect
方法的东西,但您的直觉是正确的:因为这两个操作都是惰性评估的,引擎有机会优化操作并将它们链接起来它们是在最大位置执行的。
对于您特别提到的用例,我建议您考虑 flatMap
,它适用于 RDD
s 和 Dataset
s:
// Assumes the usual spark-shell environment
// sc: SparkContext, spark: SparkSession
val collection = Seq(Some(1), None, Some(2), None, Some(3))
val rdd = sc.parallelize(collection)
val dataset = spark.createDataset(rdd)
// Both operations will yield `Array(2, 4, 6)`
rdd.flatMap(_.map(_ * 2)).collect
dataset.flatMap(_.map(_ * 2)).collect
// You can also express the operation in terms of a for-comprehension
(for (option <- rdd; n <- option) yield n * 2).collect
(for (option <- dataset; n <- option) yield n * 2).collect
// The same approach is valid for traditional collections as well
collection.flatMap(_.map(_ * 2))
for (option <- collection; n <- option) yield n * 2
编辑
正如在另一个问题中正确指出的那样,RDD
s 实际上有 collect
方法,它通过应用部分函数来转换 RDD
,就像在普通集合中发生的那样。然而,正如 Spark documentation 指出的那样,"this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."
为了完整起见:
RDD API 确实有这样的方法,所以总是可以选择将给定的Dataset / DataFrame转换为RDD,执行collect
操作并转换回来,例如:
val dataset = Seq(Some(1), None, Some(2)).toDS()
val dsResult = dataset.rdd.collect { case Some(i) => i * 2 }.toDS()
但是,这可能比在数据集上使用地图和过滤器表现更差(原因在@stefanobaghino 的回答中解释)。
至于 DataFrame,这个特定示例(使用 Option
)有些误导,因为转换为 DataFrame 实际上会将选项的 "flatenning" 转换为它们的值(或 null
对于 None
),因此等效表达式为:
val dataframe = Seq(Some(1), None, Some(2)).toDF("opt")
dataframe.withColumn("opt", $"opt".multiply(2)).filter(not(isnull($"opt")))
我认为,您对地图操作 "assume" 对其输入的任何担忧的影响较小。
我只是想通过包含一个 for
对案例 class 的理解示例来扩展 stefanobaghino 的回答,因为许多用例可能会涉及案例 classes。
此外,选项是 monad,这使得在这种情况下接受的答案非常简单,因为 for
巧妙地删除了 None
值,但这种方法不会扩展到 non-monads像案例 classes:
case class A(b: Boolean, i: Int, d: Double)
val collection = Seq(A(true, 3), A(false, 10), A(true, -1))
val rdd = ...
val dataset = ...
// Select out and double all the 'i' values where 'b' is true:
for {
A(b, i, _) <- dataset
if b
} yield i * 2
这里的答案是不正确的,至少以现在的Spark是这样。
RDD 实际上有一个 collect 方法,它采用部分函数并对数据应用过滤器和映射。这与无参数的 .collect() 方法完全不同。查看 Spark 源代码 RDD.scala @ line 955:
/**
* Return an RDD that contains all matching values by applying `f`.
*/
def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
val cleanF = sc.clean(f)
filter(cleanF.isDefinedAt).map(cleanF)
}
与第 923 行 RDD.scala 中的无参数 .collect() 方法相反,这不会具体化 RDD 中的数据:
/**
* Return an array that contains all of the elements in this RDD.
*/
def collect(): Array[T] = withScope {
val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
Array.concat(results: _*)
}
在文档中,请注意
def collect[U](f: PartialFunction[T, U]): RDD[U]
方法 没有 有一个关于正在加载到驱动程序内存中的数据的警告:
让这些重载方法做完全不同的事情对 Spark 来说非常混乱。
编辑:我错了!我误解了这个问题,我们在谈论数据集而不是 RDD。不过,接受的答案是
"the Spark documentation points out, however, "this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."
这是不正确的!调用 .collect() 的部分函数版本时,数据不会加载到驱动程序的内存中 - 只有在调用无参数版本时才会加载。调用 .collect(partial_function) 的性能应该与顺序调用 .filter() 和 .map() 的性能大致相同,如上面的源代码所示。
您始终可以创建自己的扩展方法:
implicit class DatasetOps[T](ds: Dataset[T]) {
def collectt[U](pf: PartialFunction[T, U])(implicit enc: Encoder[U]): Dataset[U] = {
ds.flatMap(pf.lift(_))
}
}
这样:
// val ds = Dataset(1, 2, 3)
ds.collectt { case x if x % 2 == 1 => x * 3 }
// Dataset(3, 9)
请注意,不幸的是我无法将其命名为 collect
(因此可怕的后缀 t
),否则签名会(我认为)与现有的 Dataset#collect
方法将 Dataset
转换为 Array
.
常规 scala 集合有一个漂亮的 collect
方法,它让我可以使用部分函数一次完成 filter-map
操作。在 spark Dataset
s?
我喜欢它有两个原因:
- 语法简单
- 它将
filter-map
样式操作减少到单次通过(尽管在 spark 中我猜有一些优化可以为您发现这些东西)
这里有一个例子来说明我的意思。假设我有一系列选项,我想提取和加倍定义的整数(Some
中的整数):
val input = Seq(Some(3), None, Some(-1), None, Some(4), Some(5))
方法一 - collect
input.collect {
case Some(value) => value * 2
}
// List(6, -2, 8, 10)
collect
使它在语法上非常整洁并且一次通过。
方法二 - filter-map
input.filter(_.isDefined).map(_.get * 2)
我可以将这种模式带到 spark 中,因为数据集和数据框具有类似的方法。
但我不太喜欢这个,因为 isDefined
和 get
对我来说似乎是代码的味道。有一个隐含的假设,即 map 仅接收 Some
s。编译器无法验证这一点。在一个更大的示例中,开发人员更难发现该假设,并且开发人员可能会交换过滤器和映射,而不会出现语法错误。
方法 3 - fold*
操作
input.foldRight[List[Int]](Nil) {
case (nextOpt, acc) => nextOpt match {
case Some(next) => next*2 :: acc
case None => acc
}
}
我没有充分使用 spark 来了解 fold 是否有等价物,所以这可能有点切线。
反正模式匹配、折叠样板和列表的重建都乱七八糟,很难读。
所以总的来说,我发现 collect
语法最好,我希望 spark 有这样的东西。
在RDD
s和Dataset
s上定义的collect
方法用于实现驱动程序中的数据。
尽管没有类似于 Collections API collect
方法的东西,但您的直觉是正确的:因为这两个操作都是惰性评估的,引擎有机会优化操作并将它们链接起来它们是在最大位置执行的。
对于您特别提到的用例,我建议您考虑 flatMap
,它适用于 RDD
s 和 Dataset
s:
// Assumes the usual spark-shell environment
// sc: SparkContext, spark: SparkSession
val collection = Seq(Some(1), None, Some(2), None, Some(3))
val rdd = sc.parallelize(collection)
val dataset = spark.createDataset(rdd)
// Both operations will yield `Array(2, 4, 6)`
rdd.flatMap(_.map(_ * 2)).collect
dataset.flatMap(_.map(_ * 2)).collect
// You can also express the operation in terms of a for-comprehension
(for (option <- rdd; n <- option) yield n * 2).collect
(for (option <- dataset; n <- option) yield n * 2).collect
// The same approach is valid for traditional collections as well
collection.flatMap(_.map(_ * 2))
for (option <- collection; n <- option) yield n * 2
编辑
正如在另一个问题中正确指出的那样,RDD
s 实际上有 collect
方法,它通过应用部分函数来转换 RDD
,就像在普通集合中发生的那样。然而,正如 Spark documentation 指出的那样,"this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."
为了完整起见:
RDD API 确实有这样的方法,所以总是可以选择将给定的Dataset / DataFrame转换为RDD,执行collect
操作并转换回来,例如:
val dataset = Seq(Some(1), None, Some(2)).toDS()
val dsResult = dataset.rdd.collect { case Some(i) => i * 2 }.toDS()
但是,这可能比在数据集上使用地图和过滤器表现更差(原因在@stefanobaghino 的回答中解释)。
至于 DataFrame,这个特定示例(使用 Option
)有些误导,因为转换为 DataFrame 实际上会将选项的 "flatenning" 转换为它们的值(或 null
对于 None
),因此等效表达式为:
val dataframe = Seq(Some(1), None, Some(2)).toDF("opt")
dataframe.withColumn("opt", $"opt".multiply(2)).filter(not(isnull($"opt")))
我认为,您对地图操作 "assume" 对其输入的任何担忧的影响较小。
我只是想通过包含一个 for
对案例 class 的理解示例来扩展 stefanobaghino 的回答,因为许多用例可能会涉及案例 classes。
此外,选项是 monad,这使得在这种情况下接受的答案非常简单,因为 for
巧妙地删除了 None
值,但这种方法不会扩展到 non-monads像案例 classes:
case class A(b: Boolean, i: Int, d: Double)
val collection = Seq(A(true, 3), A(false, 10), A(true, -1))
val rdd = ...
val dataset = ...
// Select out and double all the 'i' values where 'b' is true:
for {
A(b, i, _) <- dataset
if b
} yield i * 2
这里的答案是不正确的,至少以现在的Spark是这样。
RDD 实际上有一个 collect 方法,它采用部分函数并对数据应用过滤器和映射。这与无参数的 .collect() 方法完全不同。查看 Spark 源代码 RDD.scala @ line 955:
/**
* Return an RDD that contains all matching values by applying `f`.
*/
def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = withScope {
val cleanF = sc.clean(f)
filter(cleanF.isDefinedAt).map(cleanF)
}
与第 923 行 RDD.scala 中的无参数 .collect() 方法相反,这不会具体化 RDD 中的数据:
/**
* Return an array that contains all of the elements in this RDD.
*/
def collect(): Array[T] = withScope {
val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
Array.concat(results: _*)
}
在文档中,请注意
def collect[U](f: PartialFunction[T, U]): RDD[U]
方法 没有 有一个关于正在加载到驱动程序内存中的数据的警告:
让这些重载方法做完全不同的事情对 Spark 来说非常混乱。
编辑:我错了!我误解了这个问题,我们在谈论数据集而不是 RDD。不过,接受的答案是
"the Spark documentation points out, however, "this method should only be used if the resulting array is expected to be small, as all the data is loaded into the driver's memory."
这是不正确的!调用 .collect() 的部分函数版本时,数据不会加载到驱动程序的内存中 - 只有在调用无参数版本时才会加载。调用 .collect(partial_function) 的性能应该与顺序调用 .filter() 和 .map() 的性能大致相同,如上面的源代码所示。
您始终可以创建自己的扩展方法:
implicit class DatasetOps[T](ds: Dataset[T]) {
def collectt[U](pf: PartialFunction[T, U])(implicit enc: Encoder[U]): Dataset[U] = {
ds.flatMap(pf.lift(_))
}
}
这样:
// val ds = Dataset(1, 2, 3)
ds.collectt { case x if x % 2 == 1 => x * 3 }
// Dataset(3, 9)
请注意,不幸的是我无法将其命名为 collect
(因此可怕的后缀 t
),否则签名会(我认为)与现有的 Dataset#collect
方法将 Dataset
转换为 Array
.