plot_confusion_matrix() 使用 sklearn 得到了一个意外的关键字参数 'classes'
plot_confusion_matrix() got an unexpected keyword argument 'classes' using sklearn
我是 python 和深度学习的新手,我训练了一个多分类器模型并想绘制一个混淆矩阵,但我遇到了一个错误
这是我的代码
from sklearn.metrics import plot_confusion_matrix
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
Y_pred = model.predict_generator(test_generator)
y_pred = np.argmax(Y_pred, axis=1)
category_names = sorted(os.listdir('D:/DiabaticRetinopathy/mq_dataset/DR_Normal/train'))
print(category_names)
cm = confusion_matrix(test_generator.classes, y_pred)
plot_confusion_matrix(cm, classes = category_names, title='Confusion Matrix', normalize=False, figname = 'Confusion_matrix_concrete.jpg')
我将我的 sklearn 更新到了 0.24 版本。更新后我重新启动了内核,但仍然出现错误:
TypeError: plot_confusion_matrix() got an unexpected keyword argument 'classes'
错误指出您提供的关键字 类 不是此函数可识别的关键字。这发生在你的最后一行。
文档列出了您可以使用的关键字:
doc
使用labels
代替类,然后删除title, figname
plot_confusion_matrix(X = test_generator.classes, y_true = y_pred,labels= category_names, normalize=False)
文档:https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html
有一个关键字 labels,但没有 类,因此您可以将其更改为那个。
我是 python 和深度学习的新手,我训练了一个多分类器模型并想绘制一个混淆矩阵,但我遇到了一个错误 这是我的代码
from sklearn.metrics import plot_confusion_matrix
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
Y_pred = model.predict_generator(test_generator)
y_pred = np.argmax(Y_pred, axis=1)
category_names = sorted(os.listdir('D:/DiabaticRetinopathy/mq_dataset/DR_Normal/train'))
print(category_names)
cm = confusion_matrix(test_generator.classes, y_pred)
plot_confusion_matrix(cm, classes = category_names, title='Confusion Matrix', normalize=False, figname = 'Confusion_matrix_concrete.jpg')
我将我的 sklearn 更新到了 0.24 版本。更新后我重新启动了内核,但仍然出现错误:
TypeError: plot_confusion_matrix() got an unexpected keyword argument 'classes'
错误指出您提供的关键字 类 不是此函数可识别的关键字。这发生在你的最后一行。
文档列出了您可以使用的关键字: doc
使用labels
代替类,然后删除title, figname
plot_confusion_matrix(X = test_generator.classes, y_true = y_pred,labels= category_names, normalize=False)
文档:https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html
有一个关键字 labels,但没有 类,因此您可以将其更改为那个。