SparkML 交叉验证是否仅适用于 "label" 列?

Does SparkML Cross Validation Only Work With a "label" Column?

当我运行与一个数据集进行交叉验证example时,该数据集的标签列not名为“label”,我正在观察Spark 3.1.1 上的 IllegalArgumentException。为什么?

已修改以下代码,将“label”列重命名为“target”,并将回归模型的 labelCol 设置为“target”。此代码导致异常,而将所有内容保留在“标签”处工作正常。

from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

training = spark.createDataFrame([
    (0, "a b c d e spark", 1.0),
    (1, "b d", 0.0),
    (2, "spark f g h", 1.0),
    (3, "hadoop mapreduce", 0.0),
    (4, "b spark who", 1.0),
    (5, "g d a y", 0.0),
    (6, "spark fly", 1.0),
    (7, "was mapreduce", 0.0),
    (8, "e spark program", 1.0),
    (9, "a e c l", 0.0),
    (10, "spark compile", 1.0),
    (11, "hadoop software", 0.0)
], ["id", "text", "target"]) # try switching between "target" and "label"

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")

lr = LogisticRegression(maxIter=10, labelCol="target") #try switching between "target" and "label"

pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

paramGrid = ParamGridBuilder() \
    .addGrid(hashingTF.numFeatures, [10, 100, 1000]) \
    .addGrid(lr.regParam, [0.1, 0.01]) \
    .build()

crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(),
                          numFolds=2)  


cvModel = crossval.fit(training)

这是预期的行为吗?

您还需要向 BinaryClassificationEvaluator 提供标签列。所以如果你替换行

evaluator=BinaryClassificationEvaluator(),

evaluator=BinaryClassificationEvaluator(labelCol="target"),

它应该可以正常工作。

您可以在 docs 中找到用法。