Logistic Regression on a Dataset from generateLinearRDD fails with java.lang.IllegalArgumentException

So, as a proof of concept, I tried generating a DataFrame from the sample data produced by LinearDataGenerator.generateLinearRDD and then running a logistic regression on it.

On the assumption that generateLinearRDD produces data suited to linear regression, I plugged it into a Pipeline with a Binarizer to create a thresholded label column suitable for logistic regression.

My code is as follows:

import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils}
import org.apache.spark.ml.feature.Binarizer
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}

// Databricks users can comment out the lines between here...
  val spark = SparkSession
    .builder()
    .appName("Java Spark SQL basic example")
    .config("spark.master", "local")
    .getOrCreate()

  import spark.implicits._
// ...and here

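  // Generate a 10,000-row, 4-feature linear dataset with noise 0.05, convert
  // the mllib vector column to the newer ml vector type, and rename "label"
  // so the Binarizer below can write the binary label into a fresh "label" column.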
  val data = {
    val tmp = LinearDataGenerator.generateLinearRDD(spark.sparkContext, 10000, 4, 0.05).toDF()
    MLUtils.convertVectorColumnsToML(tmp, "features").withColumnRenamed("label", "continuousLabel")
  }

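  // Threshold the continuous label at 0 to produce the binary "label"
  // column that LogisticRegression trains on by default.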
  val binarizer = new Binarizer()
    .setInputCol("continuousLabel")
    .setOutputCol("label")
    .setThreshold(0)

  val logisticRegression = new LogisticRegression()

  val pipeline = new Pipeline()
    .setStages(Array(binarizer, logisticRegression))

  val pipelineModel = pipeline.fit(data)

  println(pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel].binarySummary.accuracy)

The stack trace of the exception looks like this:

Exception in thread "main" java.lang.IllegalArgumentException
    at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.<init>(Unknown Source)
    at org.apache.spark.util.ClosureCleaner$.getClassReader(ClosureCleaner.scala:46)
    at org.apache.spark.util.FieldAccessFinder$$anon$$anonfun$visitMethodInsn.apply(ClosureCleaner.scala:449)
    at org.apache.spark.util.FieldAccessFinder$$anon$$anonfun$visitMethodInsn.apply(ClosureCleaner.scala:432)
    at scala.collection.TraversableLike$WithFilter$$anonfun$foreach.apply(TraversableLike.scala:733)
    at scala.collection.mutable.HashMap$$anon$$anonfun$foreach.apply(HashMap.scala:134)
    at scala.collection.mutable.HashMap$$anon$$anonfun$foreach.apply(HashMap.scala:134)
    at scala.collection.mutable.HashTable$class.foreachEntry(HashTable.scala:236)
    at scala.collection.mutable.HashMap.foreachEntry(HashMap.scala:40)
    at scala.collection.mutable.HashMap$$anon.foreach(HashMap.scala:134)
    at scala.collection.TraversableLike$WithFilter.foreach(TraversableLike.scala:732)
    at org.apache.spark.util.FieldAccessFinder$$anon.visitMethodInsn(ClosureCleaner.scala:432)
    at org.apache.xbean.asm5.ClassReader.a(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.b(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.accept(Unknown Source)
    at org.apache.xbean.asm5.ClassReader.accept(Unknown Source)
    at org.apache.spark.util.ClosureCleaner$$anonfun$org$apache$spark$util$ClosureCleaner$$clean.apply(ClosureCleaner.scala:262)
    at org.apache.spark.util.ClosureCleaner$$anonfun$org$apache$spark$util$ClosureCleaner$$clean.apply(ClosureCleaner.scala:261)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at org.apache.spark.util.ClosureCleaner$.org$apache$spark$util$ClosureCleaner$$clean(ClosureCleaner.scala:261)
    at org.apache.spark.util.ClosureCleaner$.clean(ClosureCleaner.scala:159)
    at org.apache.spark.SparkContext.clean(SparkContext.scala:2299)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2073)
    at org.apache.spark.SparkContext.runJob(SparkContext.scala:2099)
    at org.apache.spark.rdd.RDD$$anonfun$collect.apply(RDD.scala:939)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.RDD.collect(RDD.scala:938)
    at org.apache.spark.rdd.PairRDDFunctions$$anonfun$collectAsMap.apply(PairRDDFunctions.scala:743)
    at org.apache.spark.rdd.PairRDDFunctions$$anonfun$collectAsMap.apply(PairRDDFunctions.scala:742)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
    at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:742)
    at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass$lzycompute(MulticlassMetrics.scala:48)
    at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass(MulticlassMetrics.scala:44)
    at org.apache.spark.mllib.evaluation.MulticlassMetrics.accuracy$lzycompute(MulticlassMetrics.scala:168)
    at org.apache.spark.mllib.evaluation.MulticlassMetrics.accuracy(MulticlassMetrics.scala:168)
    at org.apache.spark.ml.classification.LogisticRegressionSummary$class.accuracy(LogisticRegression.scala:1445)
    at org.apache.spark.ml.classification.LogisticRegressionSummaryImpl.accuracy(LogisticRegression.scala:1641)
    at crossvalidation_graphs$.delayedEndpoint$crossvalidation_graphs(crossvalidation_graphs.scala:35)
    at crossvalidation_graphs$delayedInit$body.apply(crossvalidation_graphs.scala:9)
    at scala.Function0$class.apply$mcV$sp(Function0.scala:34)
    at scala.runtime.AbstractFunction0.apply$mcV$sp(AbstractFunction0.scala:12)
    at scala.App$$anonfun$main.apply(App.scala:76)
    at scala.App$$anonfun$main.apply(App.scala:76)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.generic.TraversableForwarder$class.foreach(TraversableForwarder.scala:35)
    at scala.App$class.main(App.scala:76)
    at crossvalidation_graphs$.main(crossvalidation_graphs.scala:9)
    at crossvalidation_graphs.main(crossvalidation_graphs.scala)

My schema currently looks like this:

root
 |-- continuousLabel: double (nullable = false)
 |-- features: vector (nullable = true)

I'm running Spark 2.3.1 with Scala 2.11.12.

Like this guy, my actual problem was that I was running Java 10 instead of Java 8. When I switched back to Java 8, my code worked fine.
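For anyone who hits the same trace: the bare IllegalArgumentException is thrown from org.apache.xbean.asm5.ClassReader.<init>, and that is what ASM 5 does when it is handed class files compiled for a Java version newer than it understands. Spark 2.3.x only supports Java 8, and its ClosureCleaner runs every closure's bytecode through ASM before shipping it to executors. A minimal sanity check you can drop into the driver, using nothing beyond the standard library:

  // Print the JVM version the driver is actually running on. Spark 2.3's
  // ClosureCleaner parses closure bytecode with ASM 5, whose ClassReader
  // constructor rejects class files newer than Java 8 with a bare
  // IllegalArgumentException, matching the top frames of the trace above.
  println(System.getProperty("java.version")) // expect something like "1.8.0_181", not "10"

If it prints a 9+ version, pointing JAVA_HOME at a JDK 8 installation (and restarting sbt or your IDE) before rerunning should be enough.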