spark如何解释reduce中的列类型

How does spark interprets type of a column in reduce

我有以下table

DEST_COUNTRY_NAME   ORIGIN_COUNTRY_NAME count
United States       Romania             15
United States       Croatia             1
United States       Ireland             344
Egypt               United States       15  

table 表示为数据集。

scala> dataDS
res187: org.apache.spark.sql.Dataset[FlightData] = [DEST_COUNTRY_NAME: string, ORIGIN_COUNTRY_NAME: string ... 1 more field]

dataDS 的架构是

scala> dataDS.printSchema;
root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)

我想对 count 列的所有值求和。我想我可以使用 Datasetreduce 方法来做到这一点。

我以为我可以执行以下操作但出现错误

scala> (dataDS.select(col("count"))).reduce((acc,n)=>acc+n);
<console>:38: error: type mismatch;
 found   : org.apache.spark.sql.Row
 required: String
       (dataDS.select(col("count"))).reduce((acc,n)=>acc+n);
                                                         ^

为了使代码正常工作,我必须明确指定 countInt,即使在模式中它是 Int

scala> (dataDS.select(col("count").as[Int])).reduce((acc,n)=>acc+n);

为什么我必须明确指定 count 的类型?为什么 Scala 的 type inference 不起作用?事实上,中间 Dataset 的模式也将 count 推断为 Int.

dataDS.select(col("count")).printSchema;
root
 |-- count: integer (nullable = true)

我认为你需要换一种方式。我将假设 FlightData 是 case class 与上述模式。因此,解决方案是使用 map 和 reduce 如下

val totalSum = dataDS.map(_.count).reduce(_+_) //this line replace the above error as col("count") can't be selected.

已更新:推理问题与数据集无关,实际上,当您使用select时,您将处理Dataframe(如果加入则相同)不是静态类型的架构,您将失去案例的功能 class。例如,select 的类型将是 Dataframe 而不是 Dataset,因此您将无法推断类型。

val x: DataFrame = dataDS.select('count)
val x: Dataset[Int] = dataDS.map(_.count)

此外,从这个 要从 Column 中获得 TypedColumn,您只需使用 myCol.as[T].

我做了一个简单的例子来重现代码和数据。

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object EntryMainPoint extends App {

  //val warehouseLocation = "file:${system:user.dir}/spark-warehouse"
  val spark = SparkSession
    .builder()
    .master("local[*]")
    .appName("SparkSessionZipsExample")
    //.config("spark.sql.warehouse.dir", warehouseLocation)
    .getOrCreate()

  val someData = Seq(
    Row("United States", "Romania", 15),
    Row("United States", "Croatia", 1),
    Row("United States", "Ireland", 344),
    Row("Egypt", "United States", 15)
  )


  val flightDataSchema = List(
    StructField("DEST_COUNTRY_NAME", StringType, true),
    StructField("ORIGIN_COUNTRY_NAME", StringType, true),
    StructField("count", IntegerType, true)
  )

  case class FlightData(DEST_COUNTRY_NAME: String, ORIGIN_COUNTRY_NAME: String, count: Int)
  import spark.implicits._

  val dataDS = spark.createDataFrame(
    spark.sparkContext.parallelize(someData),
    StructType(flightDataSchema)
  ).as[FlightData]

  val totalSum = dataDS.map(_.count).reduce(_+_) //this line replace the above error as col("count") can't be selected.
  println("totalSum = " + totalSum)


  dataDS.printSchema()
  dataDS.show()


}

输出低于

totalSum = 375

root
 |-- DEST_COUNTRY_NAME: string (nullable = true)
 |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
 |-- count: integer (nullable = true)

+-----------------+-------------------+-----+
|DEST_COUNTRY_NAME|ORIGIN_COUNTRY_NAME|count|
+-----------------+-------------------+-----+
|    United States|            Romania|   15|
|    United States|            Croatia|    1|
|    United States|            Ireland|  344|
|            Egypt|      United States|   15|
+-----------------+-------------------+-----+

注意:您可以使用以下方式从数据集中执行 selection

val countColumn = dataDS.select('count) //or map(_.count)

你也可以看看这个

只需按照类型或查看编译器消息即可。

  • 您从 Dataset[FlightData] 开始。

  • 你称它为 select 并以 col("count") 作为参数。 col(_) returns Column

  • The only variant of Dataset.select which takes Column returns DataFrame which is an alias for Dataset[Row].

  • Dataset.reduce 有两种变体,一种采用 ReduceFunction[T],第二种采用 (T, T) => T,其中 T 是 [= 的类型构造函数参数29=],即Dataset[T](acc,n)=>acc+n 函数是 Scala 匿名函数,因此适用第二个版本。

  • 展开:

    (dataDS.select(col("count")): Dataset[Row]).reduce((acc: Row, n: Row) => acc + n): Row
    

    设置约束 - 函数采用 RowRow 和 returns Row.

  • Row没有+方法,所以唯一满足

    的选项
    (acc: ???, n: Row) => acc + n)
    

    就是用String(可以+AnyString.

    但是这不满足完整的表达式 - 因此出现错误。

  • 你已经想到可以使用

    dataDS.select(col("count").as[Int]).reduce((acc, n) => acc + n)
    

    其中 col("count").as[Int]TypedColumn[Row, Int]corresponding select returns Dataset[Int].

    同样你可以

    dataDS.select(col("count")).as[Int].reduce((acc, n) => acc + n)
    

    dataDS.toDF.map(_.getAs[Int]("count")).reduce((acc, n) => acc + n)
    

    在所有情况下

    .reduce((acc, n) => acc + n)
    

    正在(Int, Int) => Int.