如何在 MLReader 上创建通用函数

How can I make a function generic on an MLReader

我正在使用 Spark 1.6.3。这里有两个函数做同样的事情:

def modelFromBytesCV(modelArray: Array[Byte]): CountVectorizerModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  CountVectorizerModel.read.load(tempPath.toString)
}

def modelFromBytesIDF(modelArray: Array[Byte]): IDFModel = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  Files.write(tempPath, modelArray)
  IDFModel.read.load(tempPath.toString)
}

我想让这些函数通用。我挂断的是 CountVectorizerModel 对象和 IDFModel 之间的共同特征是 MLReadable[T],它本身必须采用 CountVectorizerModel 或 IDFModel 作为类型。这是一种递归父 class 循环,我无法找到解决方案。

相比之下,通用模型编写器很容易,因为 MLWritable 是我感兴趣的所有模型扩展的共同特征:

def modelToBytes[M <: MLWritable](model: M): Array[Byte] = {
  val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
  model.write.overwrite().save(tempPath.toString)
  Files.readAllBytes(tempPath)
}

如何制作一个通用的 reader 将 spark-ml 模型转换为字节数组?

要使其正常工作,您需要访问特定的 MlReadable 对象。

import org.apache.spark.ml.util.MLReadable

def modelFromBytes[M](obj: MLReadable[M], modelArray: Array[Byte]): M = {
  val tempPath: Path = ???
  ...
  obj.read.load(tempPath.toString)
}

以后可以用作:

val bytes: Array[Byte] = ???
modelFromBytes(CountVectorizerModel, bytes)

请注意,尽管第一次出现,但这里没有任何递归 - MLReadable[M] 指的是伴随对象,而不是 class 本身。因此,例如 CountVectorizerModel object is MLReadable, while CountVectorizeModel class 不是。

在内部,Spark MLReader 以不同的方式处理此问题 - it creates an instance of the class using reflection, and then sets its Params。然而这条路在这里对你不是很有用*。

如果需要与当前 API 兼容,您可以尝试使可读对象隐式化:

def modelFromBytes[M](modelArray: Array[Byte])(implicit obj: MLReadable[M]): M = {
  ...
}

然后是

implicit val readable: MLReadable[CountVectorizerModel] = CountVectorizerModel

modelFromBytes[CountVectorizerModel](bytes)

* 从技术上讲,可以通过反射获得伴生对象

def modelFromBytesCV[M <: MLWritable](
    modelArray: Array[Byte])(implicit ct: ClassTag[M]): M = {
  val tempPath: Path = ???
  ...
  val cls = Class.forName(ct.runtimeClass.getName + "$");
  cls.getField("MODULE$").get(cls).asInstanceOf[MLReadable[M]]
    .read.load(tempPath.toString)) 
}

但我认为这不值得在这里探索。特别是我们在这里不能真正提供严格的类型界限——使用 MLWritable 是一种限制人为错误的技巧,但对编译器来说毫无用处。