ApacheSparkML StringIndexer 吃掉了我的专栏
ApacheSparkML StringIndexer eats my columns
将 StringIndexer 应用于包含以下列的 df_notnull(DataFrame 对象)时:
scala> df_notnull.printSchema
root
|-- L0_S22_F545: string (nullable = true)
|-- L0_S0_F0: double (nullable = true)
|-- L0_S0_F2: double (nullable = true)
|-- L0_S0_F4: double (nullable = true)
只剩下那些:
scala> indexed.printSchema
root
|-- L0_S22_F545: string (nullable = true)
|-- L0_S22_F545Index: double (nullable = true)
这是我的代码:
:paste
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
val indexer = new StringIndexer()
.setInputCol("L0_S22_F545")
.setOutputCol("L0_S22_F545Index")
val indexed = indexer.fit(df_notnull).transform(df_notnull)
indexed.printSchema
我想保留所有列,只添加一些新列。我做错了什么?
找到解决方案。实际上,转换器不应该单独使用,而应该与管道一起使用——这样列就被保留了:
import org.apache.spark.ml.Pipeline
val transformers = Array(
indexer,
encoder
)
var pipeline = new Pipeline().setStages(transformers).fit(df_notnull)
var transformed = pipeline.transform(df_notnull)
结果是这样的:
scala> transformed.show
+-----------+--------+--------+--------+----------------+--------------+
|L0_S22_F545|L0_S0_F0|L0_S0_F2|L0_S0_F4|L0_S22_F545Index|L0_S22_F545Vec|
+-----------+--------+--------+--------+----------------+--------------+
| NA| 0.03| -0.034| -0.197| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| 0.088| 0.086| 0.003| 0.0|(13,[0],[1.0])|
| NA| -0.036| -0.064| 0.294| 0.0|(13,[0],[1.0])|
| NA| -0.055| -0.086| 0.294| 0.0|(13,[0],[1.0])|
| NA| 0.003| 0.019| 0.294| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| -0.016| -0.041| -0.179| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| 0.016| 0.093| -0.015| 0.0|(13,[0],[1.0])|
| NA| -0.062| -0.153| -0.197| 0.0|(13,[0],[1.0])|
| NA| -0.075| -0.093| 0.367| 0.0|(13,[0],[1.0])|
| NA| -0.003| -0.093| -0.161| 0.0|(13,[0],[1.0])|
| NA| -0.016| -0.138| -0.197| 0.0|(13,[0],[1.0])|
| NA| 0.252| 0.25| 0.003| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| -0.016| -0.041| 0.003| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| 0.088| 0.033| 0.33| 0.0|(13,[0],[1.0])|
+-----------+--------+--------+--------+----------------+--------------+
only showing top 20 rows
将 StringIndexer 应用于包含以下列的 df_notnull(DataFrame 对象)时:
scala> df_notnull.printSchema
root
|-- L0_S22_F545: string (nullable = true)
|-- L0_S0_F0: double (nullable = true)
|-- L0_S0_F2: double (nullable = true)
|-- L0_S0_F4: double (nullable = true)
只剩下那些:
scala> indexed.printSchema
root
|-- L0_S22_F545: string (nullable = true)
|-- L0_S22_F545Index: double (nullable = true)
这是我的代码:
:paste
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
val indexer = new StringIndexer()
.setInputCol("L0_S22_F545")
.setOutputCol("L0_S22_F545Index")
val indexed = indexer.fit(df_notnull).transform(df_notnull)
indexed.printSchema
我想保留所有列,只添加一些新列。我做错了什么?
找到解决方案
import org.apache.spark.ml.Pipeline
val transformers = Array(
indexer,
encoder
)
var pipeline = new Pipeline().setStages(transformers).fit(df_notnull)
var transformed = pipeline.transform(df_notnull)
结果是这样的:
scala> transformed.show
+-----------+--------+--------+--------+----------------+--------------+
|L0_S22_F545|L0_S0_F0|L0_S0_F2|L0_S0_F4|L0_S22_F545Index|L0_S22_F545Vec|
+-----------+--------+--------+--------+----------------+--------------+
| NA| 0.03| -0.034| -0.197| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| 0.088| 0.086| 0.003| 0.0|(13,[0],[1.0])|
| NA| -0.036| -0.064| 0.294| 0.0|(13,[0],[1.0])|
| NA| -0.055| -0.086| 0.294| 0.0|(13,[0],[1.0])|
| NA| 0.003| 0.019| 0.294| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| -0.016| -0.041| -0.179| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| 0.016| 0.093| -0.015| 0.0|(13,[0],[1.0])|
| NA| -0.062| -0.153| -0.197| 0.0|(13,[0],[1.0])|
| NA| -0.075| -0.093| 0.367| 0.0|(13,[0],[1.0])|
| NA| -0.003| -0.093| -0.161| 0.0|(13,[0],[1.0])|
| NA| -0.016| -0.138| -0.197| 0.0|(13,[0],[1.0])|
| NA| 0.252| 0.25| 0.003| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| -0.016| -0.041| 0.003| 0.0|(13,[0],[1.0])|
| NA| 0.0| 0.0| 0.0| 0.0|(13,[0],[1.0])|
| NA| 0.088| 0.033| 0.33| 0.0|(13,[0],[1.0])|
+-----------+--------+--------+--------+----------------+--------------+
only showing top 20 rows