plot_confusion_matrix() 使用 sklearn 得到了一个意外的关键字参数 'classes'

plot_confusion_matrix() got an unexpected keyword argument 'classes' using sklearn

我是 python 和深度学习的新手,我训练了一个多分类器模型并想绘制一个混淆矩阵,但我遇到了一个错误 这是我的代码

from sklearn.metrics import plot_confusion_matrix
import matplotlib.pyplot as plt  
from sklearn.metrics import ConfusionMatrixDisplay
Y_pred = model.predict_generator(test_generator)
y_pred = np.argmax(Y_pred, axis=1)
category_names = sorted(os.listdir('D:/DiabaticRetinopathy/mq_dataset/DR_Normal/train'))
print(category_names)
cm = confusion_matrix(test_generator.classes, y_pred)
plot_confusion_matrix(cm, classes = category_names, title='Confusion Matrix', normalize=False, figname = 'Confusion_matrix_concrete.jpg')

我将我的 sklearn 更新到了 0.24 版本。更新后我重新启动了内核,但仍然出现错误:

TypeError: plot_confusion_matrix() got an unexpected keyword argument 'classes'

错误指出您提供的关键字 类 不是此函数可识别的关键字。这发生在你的最后一行。

文档列出了您可以使用的关键字: doc

使用labels代替类,然后删除title, figname

plot_confusion_matrix(X = test_generator.classes, y_true = y_pred,labels= category_names, normalize=False)

文档:https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html

有一个关键字 labels,但没有 类,因此您可以将其更改为那个。