ConfusionMatrixDisplay(confusion_matrix).plot() doesn't show anything

I have the following code:

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

...

confusion_matrix = confusion_matrix(validation_generator.classes, y_pred, normalize='all')

print(confusion_matrix)

display = ConfusionMatrixDisplay(confusion_matrix).plot()

Output:

[[0.013 0.487]
 [0.001 0.499]]

The problem is that no confusion matrix plot is displayed when .plot() is executed.

I ran pip freeze > requirements.txt in my venv. These are the package versions in my requirements.txt:
absl-py==0.11.0
astunparse==1.6.3
autopep8==1.5.5
cachetools==4.2.1
certifi==2020.12.5
chardet==4.0.0
cycler==0.10.0
flatbuffers==1.12
gast==0.3.3
google-auth==1.27.1
google-auth-oauthlib==0.4.3
google-pasta==0.2.0
grpcio==1.32.0
h5py==2.10.0
idna==2.10
joblib==1.0.1
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.3.4
matplotlib==3.2.0
numpy==1.19.5
oauthlib==3.1.0
opt-einsum==3.3.0
pandas==1.2.3
Pillow==8.1.2
protobuf==3.15.5
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycodestyle==2.6.0
pydot==1.4.2
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2021.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-learn==0.24.1
scipy==1.6.1
seaborn==0.11.1
six==1.15.0
sklearn==0.0
tensorboard==2.4.1
tensorboard-plugin-wit==1.8.0
tensorflow==2.4.1
tensorflow-estimator==2.4.0
termcolor==1.1.0
threadpoolctl==2.1.0
toml==0.10.2
typing-extensions==3.7.4.3
urllib3==1.26.3
Werkzeug==1.0.1
wrapt==1.12.1

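Add plt.show() after the call to .plot(). When the code runs as a plain script rather than in an interactive environment such as a notebook, .plot() only draws onto a matplotlib figure; the figure window does not appear until plt.show() is called.

Here is a demo: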

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
# Use a name other than confusion_matrix so the imported function is not shadowed
confusion_matrix1 = confusion_matrix(y_true, y_pred)
display = ConfusionMatrixDisplay(confusion_matrix1).plot()
plt.show()  # opens the figure window; without this call a script shows nothing
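
If no GUI is available at all (for example on a remote server), plt.show() has nothing to open. As a rough sketch of an alternative, the figure can be written to an image file instead. This assumes the non-interactive Agg backend is acceptable; the output file name confusion_matrix.png and the variable names cm, disp and out_path are made up for illustration, and normalize='all' is taken from the question.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: headless environment, so use a non-GUI backend
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]

# normalize='all' divides every cell by the total number of samples
cm = confusion_matrix(y_true, y_pred, normalize="all")
disp = ConfusionMatrixDisplay(cm).plot()

# Hypothetical output location; .plot() stores the Figure on disp.figure_
out_path = os.path.join(tempfile.gettempdir(), "confusion_matrix.png")
disp.figure_.savefig(out_path)
plt.close(disp.figure_)  # release the figure once it has been saved
```

With a GUI backend, the same disp object can still be displayed by calling plt.show() instead of saving.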