StringIndexer,其中类别级别作为列表传递

StringIndexer where category levels passed as list

StringIndexer 似乎是根据数据中的唯一值推断索引。当数据没有所有可能的值时,这是一个问题。下面的玩具示例考虑了三种 T 恤尺码(小号、中号和大号),但数据中只有两种(小号和大号)。我希望 StringIndexer 仍然考虑所有 3 种可能的大小。有什么方法可以使用提供的列表中的字符串索引创建列吗?最好将其作为 Transformer() 来执行,以便可以在管道中重复使用。

from pyspark.sql import Row
df = spark.createDataFrame([Row(id='0', size='Small'),
                            Row(id='1', size='Small'),
                            Row(id='2', size='Large')])
(
    StringIndexer(inputCol="size", outputCol="size_idx")
    .fit(df)
    .transform(df)
    .show()
)
+---+-----+--------+
| id| size|size_idx|
+---+-----+--------+
|  0|Small|     0.0|
|  1|Small|     0.0|
|  2|Large|     1.0|
+---+-----+--------+

期望输出

+---+-----+--------+
| id| size|size_idx|
+---+-----+--------+
|  0|Small|     0.0|
|  1|Small|     0.0|
|  2|Large|     2.0|
+---+-----+--------+

看起来您可以直接从一组标签创建 StringIndexer 模型,而不是从数据拟合。

from pyspark.sql import Row
from pyspark.ml.feature import StringIndexerModel

df = spark.createDataFrame([Row(id='0', size='Small'),
                            Row(id='1', size='Small'),
                            Row(id='2', size='Large')])

si = StringIndexerModel.from_labels(['Small', 'Medium', 'Large'],
                                    inputCol="size",
                                    outputCol="size_idx")

si.transform(df).show()
+---+-----+--------+
| id| size|size_idx|
+---+-----+--------+
|  0|Small|     0.0|
|  1|Small|     0.0|
|  2|Large|     2.0|
+---+-----+--------+