ValueError: multiclass format is not supported on ROC_Curve for text classification

I am trying to use a ROC curve to evaluate my sentiment text classifier model.

Here is my ROC code:

# ROC-AUC Curve
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
fpr, tpr, thresholds = roc_curve(y_test, y_test_hat2)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=1, label='ROC curve (area = %0.2f)' % roc_auc)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC CURVE')
plt.legend(loc="lower right")
plt.show()

Here is the error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-30-ef4ee0eff994> in <module>()
      2 from sklearn.metrics import roc_curve, auc
      3 import matplotlib.pyplot as plt
----> 4 fpr, tpr, thresholds = roc_curve(y_test, y_test_hat2)
      5 roc_auc = auc(fpr, tpr)
      6 plt.figure()

1 frames
/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_ranking.py in roc_curve(y_true, y_score, pos_label, sample_weight, drop_intermediate)
    961     """
    962     fps, tps, thresholds = _binary_clf_curve(
--> 963         y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
    964     )
    965 

/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_ranking.py in _binary_clf_curve(y_true, y_score, pos_label, sample_weight)
    729     y_type = type_of_target(y_true)
    730     if not (y_type == "binary" or (y_type == "multiclass" and pos_label is not None)):
--> 731         raise ValueError("{0} format is not supported".format(y_type))
    732 
    733     check_consistent_length(y_true, y_score, sample_weight)

ValueError: multiclass format is not supported

Here are y_test and y_test_hat2:

y_test = data_test["label"]


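from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC

# Note: fitting the vectorizer and the classifier on the training data is not shown here
vectorizer = TfidfVectorizer()
test_vectors = vectorizer.transform(data_test['tweet'])
classifier_linear2 = LinearSVC(verbose=1)
y_test_hat2 = classifier_linear2.predict(test_vectors)  # hard class labels, not scores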

Shape of test_vectors = (1096, 11330)

Shape of y_test = (1096,)

Labels in y_test = ['0', '1', '2', '3', '4']

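A ROC curve is based on soft predictions, i.e. it uses the predicted probability that an instance belongs to the positive class rather than the predicted class itself. For example, with sklearn you can obtain probabilities by calling predict_proba instead of predict (for the classifiers that provide it, example).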
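As a rough illustration of the difference between hard and soft predictions, here is a minimal sketch on a binary problem. The synthetic data from make_classification and the choice of LogisticRegression are assumptions made only for this example; they are not the OP's tweets or model. Note that LinearSVC does not provide predict_proba, but its decision_function returns continuous scores that roc_curve also accepts.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split

# Synthetic binary data (assumption: stands in for the vectorized text)
X, y = make_classification(n_samples=500, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Soft prediction: probability of the positive class.
# Column 1 of predict_proba corresponds to clf.classes_[1].
y_score = clf.predict_proba(X_test)[:, 1]

# roc_curve sweeps a threshold over the scores, which is only
# meaningful for continuous scores, not for predicted labels.
fpr, tpr, thresholds = roc_curve(y_test, y_score)
roc_auc = auc(fpr, tpr)
print(f"AUC = {roc_auc:.3f}")
```

With a classifier such as LinearSVC, the same pattern would use clf.decision_function(X_test) in place of predict_proba.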

Note: the OP used the tag multiclass-classification, but it is important to note that a ROC curve can only be applied to a binary classification problem.
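One common workaround, which is not part of the original answer, is to reduce the multiclass problem to several binary ones (one-vs-rest) and compute one ROC curve per class. The sketch below assumes synthetic five-class data with string labels '0' to '4' to mimic the OP's y_test, and uses LinearSVC.decision_function as the per-class score.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.svm import LinearSVC

# Synthetic 5-class data (assumption: stands in for the TF-IDF vectors)
X, y_int = make_classification(
    n_samples=1000, n_features=20, n_informative=10,
    n_classes=5, random_state=0,
)
y = y_int.astype(str)  # string labels '0'..'4', as in the question

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

clf = LinearSVC(max_iter=10000).fit(X_train, y_train)

# One score column per class, ordered like clf.classes_
scores = clf.decision_function(X_test)

# One indicator column per class: 1 where the sample belongs to it
y_bin = label_binarize(y_test, classes=clf.classes_)

# One binary ROC curve per class (that class vs. all the others)
roc_auc = {}
for i, cls in enumerate(clf.classes_):
    fpr, tpr, _ = roc_curve(y_bin[:, i], scores[:, i])
    roc_auc[str(cls)] = auc(fpr, tpr)

for cls, value in roc_auc.items():
    print(f"class {cls}: AUC = {value:.3f}")
```

Each resulting curve answers a binary question ("is this sample in class k or not?"), so the restriction to binary problems still holds for every individual curve.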

A short explanation of ROC curves can be found here.