在 Spark dataframe udf 中,像 struct(col1,col2) 这样的函数参数类型是什么?

In Spark dataframe udf, what is the type of function parameters which like struct(col1,col2)?

背景:

我有一个包含三列的数据框:id, x, y。 x,y 是 Double。

所以现在 df 只有两列:id ,coordinate.

我认为坐标的数据类型是collection.mutable.WrappedArray[(Double,Double)]。 所以我把它传给了udf。但是,数据类型是错误的。 运行 代码时出现错误。我不知道为什么。 struct(col1,col2) 的真实数据类型是什么?或者有其他方法可以轻松获得正确答案吗?

这是代码:

def getMedianPoint = udf((array1: collection.mutable.WrappedArray[(Double,Double)]) => {  
    var l = (array1.length/2)
    var c = array1(l)
    val x = c._1.asInstanceOf[Double]
    val y = c._2.asInstanceOf[Double]
    (x,y)
})

df.withColumn("coordinate",struct(col("x"),col("y")))
  .groupBy(col("id"))
  .agg(collect_list("coordinate").as("coordinate")
  .withColumn("median",getMedianPoint(col("coordinate")))

非常感谢!

I think the datatype of coordinate is collection.mutable.WrappedArray[(Double,Double)]

是的,你说的完全正确您在 udf 函数中定义为数据类型的内容以及作为参数传递的内容也是正确的。但是主要问题是结构列的键名。因为您一定遇到了以下问题

cannot resolve 'UDF(coordinate)' due to data type mismatch: argument 1 requires array> type, however, 'coordinate' is of array> type.;;

只需使用 alias 将结构键 重命名为

,错误就会消失
df.withColumn("coordinate",struct(col("x").as("_1"),col("y").as("_2")))
  .groupBy(col("id"))
  .agg(collect_list("coordinate").as("coordinate"))
    .withColumn("median",getMedianPoint(col("coordinate")))

以便键名匹配。

但是

这将在

引起另一个问题
  var c = array1(l)

Caused by: java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema cannot be cast to scala.Tuple2

所以我建议您将 udf 函数更改为

import org.apache.spark.sql.functions._

def getMedianPoint = udf((array1: Seq[Row]) => {
  var l = (array1.length/2)
  (array1(l)(0).asInstanceOf[Double], array1(l)(1).asInstanceOf[Double])
})

因此您甚至不需要使用 alias。所以完整的解决方案是

import org.apache.spark.sql.functions._

def getMedianPoint = udf((array1: Seq[Row]) => {
  var l = (array1.length/2)
  (array1(l)(0).asInstanceOf[Double], array1(l)(1).asInstanceOf[Double])
})

df.withColumn("coordinate",struct(col("x"),col("y")))
  .groupBy(col("id"))
  .agg(collect_list("coordinate").as("coordinate"))
    .withColumn("median",getMedianPoint(col("coordinate")))
  .show(false)

希望回答对你有帮助