包括持久性的 Spark 自定义估算器

Spark custom estimator including persistence

我想为 spark 开发一个自定义估算器,它也可以处理出色管道 API 的持久性。但正如 How to Roll a Custom Estimator in PySpark mllib 所说,目前还没有很多文档。

我有一些用 spark 编写的数据清理代码,想将其包装在自定义估算器中。包括一些 na 替换、列删除、过滤和基本特征生成(例如生日到年龄)。

我还不太清楚的是:

如果你能帮我做一个自定义估算器就太好了——尤其是持久性部分。

首先,我认为您将两种不同的东西混为一谈:

  • Estimators - 代表可以 fit-ted 的阶段。 Estimator fit 方法采用 Dataset 和 returns Transformer(模型)。
  • Transformers - 表示可以 transform 数据的阶段。

当你 fit Pipelinefits 所有 Estimators 和 returns PipelineModelPipelineModel 可以 transform 数据顺序调用 transform 模型中的所有 Transformers

how should I transfer the fitted values

这个问题没有单一的答案。一般来说,你有两个选择:

  • 将拟合模型的参数作为 Transformer 的参数传递。
  • 使 Transformer 的拟合模型参数 Params

第一种方法通常由内置 Transformer 使用,但第二种方法应该适用于一些简单的情况。

how to handle persistence

  • 如果 Transformer 仅由其 Params 定义,您可以扩展 DefaultParamsReadable.
  • 如果您使用更复杂的参数,您应该扩展 MLWritable 并实现对您的数据有意义的 MLWriter。 Spark源码中有多个示例展示了如何实现数据和元数据的读写。

如果您正在寻找一个易于理解的示例,请查看 CountVectorizer(Model) 其中:

以下使用 Scala API,但如果您真的想要...

,您可以轻松地将其重构为 Python

要事第一:

  • Estimator:实现 .fit() returns 一个 Transformer
  • Transformer:实现.transform()并操作DataFrame
  • Serialization/Deserialization:尽力使用内置参数并利用简单的 DefaultParamsWritable trait + 伴随对象 扩展 DefaultParamsReadable[T]。 a.k.a 远离 MLReader / MLWriter 并保持代码简单。
  • 参数传递:使用扩展Params的共同特征并在您的估算器和模型(a.k.a.Transformer)之间共享它

骨架代码:

// Common Parameters
trait MyCommonParams extends Params {
  final val inputCols: StringArrayParam = // usage: new MyMeanValueStuff().setInputCols(...)
    new StringArrayParam(this, "inputCols", "doc...")
    def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    def getInputCols: Array[String] = $(inputCols)

  final val meanValues: DoubleArrayParam = 
    new DoubleArrayParam(this, "meanValues", "doc...")
    // more setters and getters
}

// Estimator
class MyMeanValueStuff(override val uid: String) extends Estimator[MyMeanValueStuffModel] 
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  override def copy(extra: ParamMap): Estimator[MeanValueFillerModel] = defaultCopy(extra) // deafult
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def fit(dataset: Dataset[_]): MyMeanValueStuffModel = {
    // your logic here. I can't do all the work for you! ;)
   this.setMeanValues(meanValues)
   copyValues(new MyMeanValueStuffModel(uid + "_model").setParent(this))
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuff extends DefaultParamsReadable[MyMeanValueStuff]

// Model (Transformer)
class MyMeanValueStuffModel(override val uid: String) extends Model[MyMeanValueStuffModel] 
  with DefaultParamsWritable // Enables Serialization of MyCommonParams
  with MyCommonParams {

  override def copy(extra: ParamMap): MyMeanValueStuffModel = defaultCopy(extra) // default
  override def transformSchema(schema: StructType): StructType = schema // no changes
  override def transform(dataset: Dataset[_]): DataFrame = {
      // your logic here: zip inputCols and meanValues, toMap, replace nulls with NA functions
      // you have access to both inputCols and meanValues here!
  }
}
// Companion object enables deserialization of MyCommonParams
object MyMeanValueStuffModel extends DefaultParamsReadable[MyMeanValueStuffModel]

使用上面的代码,您可以 Serialize/Deserialize 一个包含 MyMeanValueStuff 阶段的管道。

想看看 Estimator 的一些真正简单的实现吗? MinMaxScaler! (虽然我的例子实际上更简单......)