Scala: Spark UDAF for polynomial curve fitting, getting "type (char[]) cannot be converted to the string type" error

I am trying to do polynomial curve fitting on a Spark DataFrame similar to the one below (Spark version 2.4.0.7.1.5, Scala version 2.11.12, OpenJDK 64-Bit Server VM, Java 1.8.0_232).

I wrote a UDAF for this; it registers fine but fails at runtime.

I am new to Scala and UDAFs. Could you help me figure out what is wrong with my function?

Thanks,

Sample df:

val n = 2

val data = Seq(
  (1,80.0,-0.361982467), (1,70.0,0.067847447),  (1,50.0,-0.196768255), 
  (1,40.0,-0.135489192), (1,65.0,0.005993648),  (1,75.0,0.037561161), 
  (1,60.0,-0.212658599), (1,55.0,-0.187080872), (1,85.0, 0.382061571),
  (2,80.0,-0.301982467), (2,70.0,0.097847447),  (2,50.0,-0.186768255), 
  (2,40.0,-0.105489192), (2,65.0,0.007993648),  (2,75.0,0.037561161), 
  (2,60.0,-0.226528599), (2,55.0,-0.170870872), (2,85.0, 0.320615718)
)

val df = data.toDF("id", "x","y")

UDAF code:

import org.apache.spark.sql.functions.{lit, udf}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.commons.math3.fitting.{PolynomialCurveFitter, WeightedObservedPoints}


class Fitter extends UserDefinedAggregateFunction {
  /**
   * Polynomial curve fitting
   *   y = c + a1*x + a2*x^2 + ...+ an * x^n
   * parameters:
   *    x: Array[Double]
   *    y: Array[Double]
   *    n: Int, polynomial degree
   * Return:
   *  coeff: the fitted parameters [c, a1, a2,...,an]
   */

  private def polyCurveFitting(x: Array[Double], y: Array[Double], n: Int): String = {

    val obs = new WeightedObservedPoints()

    for (i <- 0 until x.size) {
      obs.add(x(i), y(i))
    }

    // Instantiate a polynomial fitter of degree n.
    val fitter = PolynomialCurveFitter.create(n)

    // Retrieve fitted parameters (coefficients of the polynomial function).
    val coeff = fitter.fit(obs.toList())

    coeff.mkString("|")
  }

  override def inputSchema: StructType =
    new StructType().add(StructField("x", DoubleType))
                    .add(StructField("y", DoubleType))
                    .add(StructField("n", IntegerType))

  override def bufferSchema: StructType =
    new StructType().add(StructField("x_", ArrayType(DoubleType, false)))
                    .add(StructField("y_", ArrayType(DoubleType, false)))
                    .add(StructField("n_", IntegerType))

  override def dataType: DataType = StringType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, Array[Double]())
    buffer.update(1, Array[Double]())
    buffer.update(2, 0)
  }

  // Append the incoming (x, y) pair to the buffer and remember the polynomial degree n.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if(!input.isNullAt(0)) {
      buffer(0) = buffer.getSeq[Double](0).toArray :+ input.getAs[Double](0)
      buffer(1) = buffer.getSeq[Double](1).toArray :+ input.getAs[Double](1)
      buffer(2) = input.getAs[Int](2)
    }
  }

  // Merge two partial buffers by concatenating their accumulated x and y arrays.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getSeq[Double](0).toArray ++ buffer2.getSeq[Double](0)
    buffer1(1) = buffer1.getSeq[Double](1).toArray ++ buffer2.getSeq[Double](1)
    buffer1(2) = buffer2.getAs[Int](2)
  }

  def evaluate(buffer: Row): Array[Char] =
    polyCurveFitting(buffer.getSeq[Double](0).toArray,
                     buffer.getSeq[Double](1).toArray,
                     buffer.getAs[Int](2)).toArray
}


Calling the function:

val fitter_test = new Fitter()

spark.udf.register("fitter", fitter_test)

df.createOrReplaceTempView("test")

spark.sql("select fitter(x,y,2) from test group by id").show()

val df_poly = df.groupBy("id").agg(fitter_test($"x", $"y", lit(n)).as("estimated_parameters"))

df_poly.show()

Expected output (mock):

+---+-----------------------------------------------------------------+
| id|                                             estimated_parameters|
+---+-----------------------------------------------------------------+
|  1|"0.5034579587428405|-0.026916449551428016|2.6802822386554184E-4" |
|  2|"0.5344951514280016|-0.020286916457958744|2.6916469164575874E-4" |
+---+-----------------------------------------------------------------+

Error message:

WARN scheduler.TaskSetManager: Lost task 18.0 in stage 7.0 (TID 27, -----.analytics.loc, executor 19): java.lang.IllegalArgumentException: The value ([C@52a57e78) of the type (char[]) cannot be converted to the string type
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StringConverter$.toCatalystImpl(CatalystTypeConverters.scala:290)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StringConverter$.toCatalystImpl(CatalystTypeConverters.scala:285)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter.apply(CatalystTypeConverters.scala:396)
    at org.apache.spark.sql.execution.aggregate.ScalaUDAF.eval(udaf.scala:444)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateResultProjection.apply(AggregationIterator.scala:232)
    at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$generateResultProjection.apply(AggregationIterator.scala:224)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:150)
    at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:266)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun.apply(SparkPlan.scala:257)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:858)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$$anonfun$apply.apply(RDD.scala:858)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:346)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:310)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:123)
    at org.apache.spark.executor.Executor$TaskRunner$$anonfun.apply(Executor.scala:408)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1289)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:414)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
ERROR scheduler.TaskSetManager: Task 18 in stage 7.0 failed 4 times; aborting job
org.apache.spark.SparkException: Job aborted due to stage failure: Task 18 in stage 7.0 failed 4 times, most recent failure: Lost task 18.3 in stage 7.0 (TID 52, --------.analytics.loc, executor 19): java.lang.IllegalArgumentException: The value ([C@4f761fc2) of the type (char[]) cannot be converted to the string type

I think the problem is the return type of the evaluate method. Spark expects a String, as you declared in dataType, so it detects the mismatch between the declared result type and the char[] that evaluate actually returns. If you drop the .toArray in evaluate and return the String directly, the error should go away.
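
For example, here is a minimal sketch of that fix (assuming the rest of the Fitter class stays exactly as posted); evaluate simply returns the "|"-joined string produced by polyCurveFitting:

// Return the fitted coefficients as a String, matching dataType = StringType.
override def evaluate(buffer: Row): String =
  polyCurveFitting(buffer.getSeq[Double](0).toArray,
                   buffer.getSeq[Double](1).toArray,
                   buffer.getAs[Int](2))

The declared dataType (StringType) then matches what evaluate actually returns, so Catalyst can convert the result without complaining.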