ConfusionMatrixDisplay (Scikit-Learn) 绘图标签超出范围
ConfusionMatrixDisplay (Scikit-Learn) plot labels out of range
以下代码绘制了一个混淆矩阵:
from sklearn.metrics import ConfusionMatrixDisplay
confusion_matrix = confusion_matrix(y_true, y_pred)
target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.savefig("conf.png")
这个情节有两个问题。
- y轴标签被截断(True Label)。 x标签也被切断了。
- x 轴的名称太长了。
为了解决第一个问题,我尝试使用 poof(bbox_inches='tight')
,不幸的是,它不适用于 sklearn。
在第二种情况下,我为 尝试了以下解决方案,这导致了一个完全扭曲的情节。
总而言之,我正在为这两个问题而苦苦挣扎。
我认为最简单的方法是切换到 tight_layout
并添加 pad_inches=
内容。
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from numpy.random import default_rng
rand = default_rng()
y_true = rand.integers(low=0, high=7, size=500)
y_pred = rand.integers(low=0, high=7, size=500)
confusion_matrix = confusion_matrix(y_true, y_pred)
target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.tight_layout()
plt.savefig("conf.png", pad_inches=5)
结果:
以下代码绘制了一个混淆矩阵:
from sklearn.metrics import ConfusionMatrixDisplay
confusion_matrix = confusion_matrix(y_true, y_pred)
target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.savefig("conf.png")
这个情节有两个问题。
- y轴标签被截断(True Label)。 x标签也被切断了。
- x 轴的名称太长了。
为了解决第一个问题,我尝试使用 poof(bbox_inches='tight')
,不幸的是,它不适用于 sklearn。
在第二种情况下,我为
总而言之,我正在为这两个问题而苦苦挣扎。
我认为最简单的方法是切换到 tight_layout
并添加 pad_inches=
内容。
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from numpy.random import default_rng
rand = default_rng()
y_true = rand.integers(low=0, high=7, size=500)
y_pred = rand.integers(low=0, high=7, size=500)
confusion_matrix = confusion_matrix(y_true, y_pred)
target_names = ["aaaaa", "bbbbbb", "ccccccc", "dddddddd", "eeeeeeeeee", "ffffffff", "ggggggggg"]
disp = ConfusionMatrixDisplay(confusion_matrix=confusion_matrix, display_labels=target_names)
disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
plt.tight_layout()
plt.savefig("conf.png", pad_inches=5)
结果: