具有复杂输入参数的 Spark SQL UDF

Spark SQL UDF with complex input parameter

我正在尝试将 UDF 与结构数组的输入类型一起使用。 我有以下数据结构,这只是更大结构的相关部分

|--investments: array (nullable = true)
    |    |-- element: struct (containsNull = true)
    |    |    |-- funding_round: struct (nullable = true)
    |    |    |    |-- company: struct (nullable = true)
    |    |    |    |    |-- name: string (nullable = true)
    |    |    |    |    |-- permalink: string (nullable = true)
    |    |    |    |-- funded_day: long (nullable = true)
    |    |    |    |-- funded_month: long (nullable = true)
    |    |    |    |-- funded_year: long (nullable = true)
    |    |    |    |-- raised_amount: long (nullable = true)
    |    |    |    |-- raised_currency_code: string (nullable = true)
    |    |    |    |-- round_code: string (nullable = true)
    |    |    |    |-- source_description: string (nullable = true)
    |    |    |    |-- source_url: string (nullable = true)

我声明了案例类:

case class Company(name: String, permalink: String)
case class FundingRound(company: Company, funded_day: Long, funded_month: Long, funded_year: Long, raised_amount: Long, raised_currency_code: String, round_code: String, source_description: String, source_url: String)
case class Investments(funding_round: FundingRound)

UDF声明:

sqlContext.udf.register("total_funding", (investments:Seq[Investments])  => {
     val totals = investments.map(r => r.funding_round.raised_amount)
     totals.sum
})

当我执行以下转换时,结果符合预期

scala> sqlContext.sql("""select total_funding(investments) from companies""")
res11: org.apache.spark.sql.DataFrame = [_c0: bigint]

但是当像收集这样的动作执行时我有一个错误:

Executor: Exception in task 0.0 in stage 4.0 (TID 10)
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to $line33.$read$$iwC$$iwC$Investments

感谢您的帮助。

您看到的错误应该是不言自明的。 Catalyst / SQL 类型和 Scala 类型之间存在严格的映射,可以在 the relevant section of the Spark SQL, DataFrames and Datasets Guide.

中找到

特别是 struct 类型转换为 o.a.s.sql.Row(在您的特定情况下,数据将公开为 Seq[Row])。

有多种方法可用于将数据公开为特定类型:

  • (user defined type) which has been removed in 2.0.0 and has no replacement for now.
  • 正在将 DataFrame 转换为 Dataset[T],其中 T 是所需的本地类型。

只有前一种方法可以适用于这种特定情况。

如果您想使用 UDF 访问 investments.funding_round.raised_amount,您需要这样的东西:

val getRaisedAmount = udf((investments: Seq[Row]) => scala.util.Try(
  investments.map(_.getAs[Row]("funding_round").getAs[Long]("raised_amount"))
).toOption)

但简单 select 应该更安全、更清洁:

df.select($"investments.funding_round.raised_amount")

我创建了一个简单的库,它根据输入类型参数为复杂的产品类型导出必要的编码器。

https://github.com/lesbroot/typedudf

import typedudf.TypedUdf
import typedudf.ParamEncoder._

case class Foo(x: Int, y: String)
val fooUdf = TypedUdf((foo: Foo) => foo.x + foo.y.length)
df.withColumn("sum", fooUdf($"foo"))