管道中 Spark Dataframe 中的 OneHotEncoder

OneHotEncoder in Spark Dataframe in Pipeline

我一直在尝试使用 adult dataset 在 Spark 和 Scala 中获取示例 运行。

使用 Scala 2.11.8 和 Spark 1.6.1。

问题(目前)在于该数据集中的分类特征数量,在 Spark ML 算法完成其工作之前,所有这些特征都需要编码成数字。

到目前为止我有这个:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object Adult {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("Adult example").setMaster("local[*]")
    val sparkContext = new SparkContext(conf)
    val sqlContext = new SQLContext(sparkContext)

    val data = sqlContext.read
      .format("com.databricks.spark.csv")
      .option("header", "true") // Use first line of all files as header
      .option("inferSchema", "true") // Automatically infer data types
      .load("src/main/resources/adult.data")

    val categoricals = data.dtypes filter (_._2 == "StringType")
    val encoders = categoricals map (cat => new OneHotEncoder().setInputCol(cat._1).setOutputCol(cat._1 + "_encoded"))
    val features = data.dtypes filterNot (_._1 == "label") map (tuple => if(tuple._2 == "StringType") tuple._1 + "_encoded" else tuple._1)

    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val pipeline = new Pipeline()
      .setStages(encoders ++ Array(lr))

    val model = pipeline.fit(training)
  }
}

但是,这不起作用。调用 pipeline.fit 仍然包含原始字符串特征,因此会抛出异常。 如何删除管道中的这些 "StringType" 列? 或者也许我做的完全错了,所以如果有人有不同的建议,我很乐意接受所有意见:)。

我选择遵循这个流程的原因是因为我在 Python 和 Pandas 方面有广泛的背景,但我正在尝试学习 Scala 和 Spark。

如果您习惯了更高级别的框架,这里有一件事可能会让您感到困惑。在使用编码器之前,您必须索引这些功能。正如 the API docs 中所解释的:

one-hot encoder (...) maps a column of category indices to a column of binary vectors, with at most a single one-value per row that indicates the input category index.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}

val df = Seq((1L, "foo"), (2L, "bar")).toDF("id", "x")

val categoricals = df.dtypes.filter (_._2 == "StringType") map (_._1)

val indexers = categoricals.map (
  c => new StringIndexer().setInputCol(c).setOutputCol(s"${c}_idx")
)

val encoders = categoricals.map (
  c => new OneHotEncoder().setInputCol(s"${c}_idx").setOutputCol(s"${c}_enc")
)

val pipeline = new Pipeline().setStages(indexers ++ encoders)

val transformed = pipeline.fit(df).transform(df)
transformed.show

// +---+---+-----+-------------+
// | id|  x|x_idx|        x_enc|
// +---+---+-----+-------------+
// |  1|foo|  1.0|    (1,[],[])|
// |  2|bar|  0.0|(1,[0],[1.0])|
// +---+---+-----+-------------+

如您所见,无需从管道中删除字符串列。实际上 OneHotEncoder 将接受具有 NominalAttributeBinaryAttribute 或缺少类型属性的数字列。