如何在 MLReader 上创建通用函数
How can I make a function generic on an MLReader
我正在使用 Spark 1.6.3。这里有两个函数做同样的事情:
def modelFromBytesCV(modelArray: Array[Byte]): CountVectorizerModel = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
Files.write(tempPath, modelArray)
CountVectorizerModel.read.load(tempPath.toString)
}
def modelFromBytesIDF(modelArray: Array[Byte]): IDFModel = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
Files.write(tempPath, modelArray)
IDFModel.read.load(tempPath.toString)
}
我想让这些函数通用。我挂断的是 CountVectorizerModel 对象和 IDFModel 之间的共同特征是 MLReadable[T],它本身必须采用 CountVectorizerModel 或 IDFModel 作为类型。这是一种递归父 class 循环,我无法找到解决方案。
相比之下,通用模型编写器很容易,因为 MLWritable 是我感兴趣的所有模型扩展的共同特征:
def modelToBytes[M <: MLWritable](model: M): Array[Byte] = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
model.write.overwrite().save(tempPath.toString)
Files.readAllBytes(tempPath)
}
如何制作一个通用的 reader 将 spark-ml 模型转换为字节数组?
要使其正常工作,您需要访问特定的 MlReadable
对象。
import org.apache.spark.ml.util.MLReadable
def modelFromBytes[M](obj: MLReadable[M], modelArray: Array[Byte]): M = {
val tempPath: Path = ???
...
obj.read.load(tempPath.toString)
}
以后可以用作:
val bytes: Array[Byte] = ???
modelFromBytes(CountVectorizerModel, bytes)
请注意,尽管第一次出现,但这里没有任何递归 - MLReadable[M]
指的是伴随对象,而不是 class 本身。因此,例如 CountVectorizerModel
object is MLReadable
, while CountVectorizeModel
class 不是。
在内部,Spark MLReader
以不同的方式处理此问题 - it creates an instance of the class using reflection, and then sets its Params
。然而这条路在这里对你不是很有用*。
如果需要与当前 API 兼容,您可以尝试使可读对象隐式化:
def modelFromBytes[M](modelArray: Array[Byte])(implicit obj: MLReadable[M]): M = {
...
}
然后是
implicit val readable: MLReadable[CountVectorizerModel] = CountVectorizerModel
modelFromBytes[CountVectorizerModel](bytes)
* 从技术上讲,可以通过反射获得伴生对象
def modelFromBytesCV[M <: MLWritable](
modelArray: Array[Byte])(implicit ct: ClassTag[M]): M = {
val tempPath: Path = ???
...
val cls = Class.forName(ct.runtimeClass.getName + "$");
cls.getField("MODULE$").get(cls).asInstanceOf[MLReadable[M]]
.read.load(tempPath.toString))
}
但我认为这不值得在这里探索。特别是我们在这里不能真正提供严格的类型界限——使用 MLWritable
是一种限制人为错误的技巧,但对编译器来说毫无用处。
我正在使用 Spark 1.6.3。这里有两个函数做同样的事情:
def modelFromBytesCV(modelArray: Array[Byte]): CountVectorizerModel = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
Files.write(tempPath, modelArray)
CountVectorizerModel.read.load(tempPath.toString)
}
def modelFromBytesIDF(modelArray: Array[Byte]): IDFModel = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
Files.write(tempPath, modelArray)
IDFModel.read.load(tempPath.toString)
}
我想让这些函数通用。我挂断的是 CountVectorizerModel 对象和 IDFModel 之间的共同特征是 MLReadable[T],它本身必须采用 CountVectorizerModel 或 IDFModel 作为类型。这是一种递归父 class 循环,我无法找到解决方案。
相比之下,通用模型编写器很容易,因为 MLWritable 是我感兴趣的所有模型扩展的共同特征:
def modelToBytes[M <: MLWritable](model: M): Array[Byte] = {
val tempPath: Path = KAZOO_TEMP_DIR.resolve(s"model_${System.currentTimeMillis()}")
model.write.overwrite().save(tempPath.toString)
Files.readAllBytes(tempPath)
}
如何制作一个通用的 reader 将 spark-ml 模型转换为字节数组?
要使其正常工作,您需要访问特定的 MlReadable
对象。
import org.apache.spark.ml.util.MLReadable
def modelFromBytes[M](obj: MLReadable[M], modelArray: Array[Byte]): M = {
val tempPath: Path = ???
...
obj.read.load(tempPath.toString)
}
以后可以用作:
val bytes: Array[Byte] = ???
modelFromBytes(CountVectorizerModel, bytes)
请注意,尽管第一次出现,但这里没有任何递归 - MLReadable[M]
指的是伴随对象,而不是 class 本身。因此,例如 CountVectorizerModel
object is MLReadable
, while CountVectorizeModel
class 不是。
在内部,Spark MLReader
以不同的方式处理此问题 - it creates an instance of the class using reflection, and then sets its Params
。然而这条路在这里对你不是很有用*。
如果需要与当前 API 兼容,您可以尝试使可读对象隐式化:
def modelFromBytes[M](modelArray: Array[Byte])(implicit obj: MLReadable[M]): M = {
...
}
然后是
implicit val readable: MLReadable[CountVectorizerModel] = CountVectorizerModel
modelFromBytes[CountVectorizerModel](bytes)
* 从技术上讲,可以通过反射获得伴生对象
def modelFromBytesCV[M <: MLWritable](
modelArray: Array[Byte])(implicit ct: ClassTag[M]): M = {
val tempPath: Path = ???
...
val cls = Class.forName(ct.runtimeClass.getName + "$");
cls.getField("MODULE$").get(cls).asInstanceOf[MLReadable[M]]
.read.load(tempPath.toString))
}
但我认为这不值得在这里探索。特别是我们在这里不能真正提供严格的类型界限——使用 MLWritable
是一种限制人为错误的技巧,但对编译器来说毫无用处。