spark OneHotEncoder - 如何排除用户定义的类别?

spark OneHotEncoder - how to exclude user-defined category?

考虑以下 spark 数据框:

df.printSchema()

     |-- predictor: double (nullable = true)
     |-- label: double (nullable = true)
     |-- date: string (nullable = true)

df.show(6)

    predictor      label              date    
    4.23           6.33               20160510
    4.77           7.18               20160510
    4.09           5.94               20160511
    4.23           6.33               20160511
    4.77           7.18               20160512
    4.09           5.94               20160512

本质上,我的数据框由每日频率的数据组成。我需要将日期列映射到二进制向量列。使用 StringIndexer 和 OneHotEncoder 很容易实现:

val dateIndexer = new StringIndexer()
  .setInputCol("date")
  .setOutputCol("dateIndex")
  .fit(df)
val indexed = dateIndexer.transform(df)

val encoder = new OneHotEncoder()
  .setInputCol("dateIndex")
  .setOutputCol("date_codeVec")

val encoded = encoder.transform(indexed)

我的问题是OneHotEncoder drops the last category by default。但是,我需要删除与数据框中的第一个日期相关的类别(上例中为 20160510),因为我需要计算相对于第一个日期的时间趋势。

对于上面的示例,我如何实现这一点(请注意,我的数据框中有超过 100 个日期)?

您可以尝试将 setDropLast 设置为 false:

val encoder = new OneHotEncoder()
  .setInputCol("dateIndex")
  .setOutputCol("date_codeVec")
  .setDropLast(false)

val encoded = encoder.transform(indexed)

并手动删除关卡选择,使用 VectorSlicer:

import org.apache.spark.ml.feature.VectorSlicer

val slicer = new VectorSlicer()
  .setInputCol("date_codeVec")
  .setOutputCol("data_codeVec_selected")
  .setNames(dateIndexer.labels.diff(Seq(dateIndexer.labels.min)))

slicer.transform(encoded)
+---------+-----+--------+---------+-------------+---------------------+
|predictor|label|    date|dateIndex| date_codeVec|data_codeVec_selected|
+---------+-----+--------+---------+-------------+---------------------+
|     4.23| 6.33|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
|     4.77| 7.18|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
|     4.09| 5.94|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
|     4.23| 6.33|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
|     4.77| 7.18|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
|     4.09| 5.94|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
+---------+-----+--------+---------+-------------+---------------------+