Spark ML Pipeline API save not working
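In version 1.6 the Pipeline API gained a set of new features for saving and loading pipeline stages. After training a classifier, I tried to save a stage to disk and load it again later, so that I could reuse it and avoid recomputing the model.
For some reason, when I save the model, the directory contains only the metadata directory. When I try to load it again, I get the following exception: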
Exception in thread "main" java.lang.UnsupportedOperationException: empty collection
    at org.apache.spark.rdd.RDD$$anonfun$first.apply(RDD.scala:1330)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:150)
    at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:111)
    at org.apache.spark.rdd.RDD.withScope(RDD.scala:316)
    at org.apache.spark.rdd.RDD.first(RDD.scala:1327)
    at org.apache.spark.ml.util.DefaultParamsReader$.loadMetadata(ReadWrite.scala:284)
    at org.apache.spark.ml.tuning.CrossValidator$SharedReadWrite$.load(CrossValidator.scala:287)
    at org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelReader.load(CrossValidator.scala:393)
    at org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelReader.load(CrossValidator.scala:384)
    at org.apache.spark.ml.util.MLReadable$class.load(ReadWrite.scala:176)
    at org.apache.spark.ml.tuning.CrossValidatorModel$.load(CrossValidator.scala:368)
    at org.apache.spark.ml.tuning.CrossValidatorModel.load(CrossValidator.scala)
    at org.test.categoryminer.spark.SparkTextClassifierModelCache.get(SparkTextClassifierModelCache.java:34)
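To save the model I use: crossValidatorModel.save("/tmp/my.model")
To load it I use: CrossValidatorModel.load("/tmp/my.model")
I call save on the CrossValidatorModel object that I get back from calling fit(dataframe) on the CrossValidator object.
Any pointers as to why it only saves the metadata directory?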
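One way to narrow this down is to check what actually landed on disk at the save path. The stack trace shows loadMetadata failing inside RDD.first with "empty collection", so it may be that the metadata directory exists but holds no readable part file on the machine doing the load; that reading is an assumption, not something confirmed here. The sketch below only uses java.nio (no Spark calls), and ListSavedModel and listFiles are hypothetical names chosen for illustration.

```scala
import java.nio.file.{Files, Path, Paths}
import scala.collection.JavaConverters._

object ListSavedModel {
  // Recursively collect regular files under `root`, as (relative path, size in bytes).
  def listFiles(root: Path): Seq[(String, Long)] = {
    val stream = Files.walk(root)
    try {
      stream.iterator().asScala
        .filter(p => Files.isRegularFile(p))
        .map(p => (root.relativize(p).toString, Files.size(p)))
        .toList
        .sorted
    } finally {
      stream.close()
    }
  }

  def main(args: Array[String]): Unit = {
    // Defaults to the path used in the question; pass another path as the first argument.
    val root = Paths.get(args.headOption.getOrElse("/tmp/my.model"))
    if (Files.isDirectory(root))
      listFiles(root).foreach { case (file, size) => println(s"$size\t$file") }
    else
      println(s"$root is not a directory on this machine")
  }
}
```

If the listing shows no part file under metadata, or the directory is missing on the machine that runs the load, that would be consistent with the exception above.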
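This certainly won't answer your question directly, since I haven't personally tested the new feature in 1.6.0.
Instead, I use a dedicated function to save models.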
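import java.io.{FileNotFoundException, FileOutputStream, IOException, ObjectOutputStream}
import org.apache.spark.ml.tuning.CrossValidatorModel

// Serialize the fitted model to a local file using Java serialization.
def saveCrossValidatorModel(model: CrossValidatorModel, path: String): Unit = {
  try {
    val out = new ObjectOutputStream(new FileOutputStream(path))
    // Closing the ObjectOutputStream also closes the underlying file stream.
    try out.writeObject(model) finally out.close()
  } catch {
    case foe: FileNotFoundException => foe.printStackTrace()
    case ioe: IOException => ioe.printStackTrace()
  }
}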
You can then read the model back in a similar way:
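import java.io.{FileInputStream, FileNotFoundException, IOException, ObjectInputStream}
import org.apache.spark.ml.tuning.CrossValidatorModel

// Deserialize a model written by saveCrossValidatorModel.
def loadCrossValidatorModel(path: String): CrossValidatorModel = {
  try {
    val in = new ObjectInputStream(new FileInputStream(path))
    try in.readObject().asInstanceOf[CrossValidatorModel] finally in.close()
  } catch {
    // Log and rethrow: every branch must yield a CrossValidatorModel or throw,
    // otherwise the method does not type-check.
    case foe: FileNotFoundException => foe.printStackTrace(); throw foe
    case ioe: IOException => ioe.printStackTrace(); throw ioe
  }
}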
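The two helpers above follow the same open, write or read, close pattern, so they can be folded into one generic pair. The following is a minimal sketch of that idea, assuming the object is java.io.Serializable and the path is on the local filesystem. DummyModel is a hypothetical stand-in so the example runs without Spark on the classpath, and saveObject and loadObject are illustrative names, not Spark APIs.

```scala
import java.io.{FileInputStream, FileOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.file.Files
import scala.util.Try

object SerializationSketch {
  // Hypothetical stand-in for a fitted model; any java.io.Serializable object works.
  case class DummyModel(uid: String, weights: Vector[Double])

  // Write `obj` to `path` with Java serialization; errors are returned as Failure.
  def saveObject(obj: java.io.Serializable, path: String): Try[Unit] = Try {
    val out = new ObjectOutputStream(new FileOutputStream(path))
    try out.writeObject(obj) finally out.close()
  }

  // Read an object back from `path`. I/O errors are returned as Failure;
  // the cast to T is unchecked because of type erasure.
  def loadObject[T](path: String): Try[T] = Try {
    val in = new ObjectInputStream(new FileInputStream(path))
    try in.readObject().asInstanceOf[T] finally in.close()
  }

  def main(args: Array[String]): Unit = {
    val path = Files.createTempFile("model", ".ser").toString
    val model = DummyModel("cv_1", Vector(0.1, 0.2))
    // Save, then load again, and print the result of the round trip.
    val restored = saveObject(model, path).flatMap(_ => loadObject[DummyModel](path))
    println(restored)
  }
}
```

Returning Try instead of printing the stack trace leaves it to the caller to decide whether a failed load should trigger retraining. Note that Java serialization ties the file to the exact class versions that wrote it, so a file saved this way may not load after a Spark upgrade.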