使用 matshow 绘图时混淆矩阵不完整

Incomplete confusion matrix when plotting with matshow

我正在尝试绘制这个混淆矩阵:

[[25940  2141    84    19     3     0     0     1   184     4]
 [ 3525  6357   322    41     5     1     3     0   242     2]
 [  410  1484  1021    80     5     6     0     0   282     0]
 [   98   285   189   334     9     9     5     1   140     0]
 [   26    64    55    50   112    15     4     1    75     0]
 [   11    45    20    24     5   118     8     0    79     0]
 [    1     8     8     5     0    10    62     1    55     0]
 [    2     0     0     0     0     0     2     0     6     0]
 [  510   524   103    55     5     7     7     1 65350     0]
 [   62    13     2     1     0     0     1     0    11    13]]

因此,10x10。这 10 个标签是:

[ 5  6  7  8  9 10 11 12 14 15]

我使用以下代码:

获取混淆矩阵

cm = confusion_matrix(y_test, y_pred, labels=labels)
print('Confusion Matrix of {} is:\n{}'.format(clf_name, cm))
print(labels)
plt.matshow(cm, interpolation='nearest')
ax = plt.gca()
ax.set_xticklabels([''] + labels.astype(str).tolist())
ax.set_yticklabels([''] + labels.astype(str).tolist())
plt.title('Confusion matrix of the {} classifier'.format(clf_name))
plt.colorbar(mat, extend='both')
plt.clim(0, 100)

而且我只得到一个标签从 5 到 9 的图:

这里有什么问题?

相关导入和配置(顺便说一句,我正在使用 Jupyter):

import matplotlib.pyplot as plt
import matplotlib as mpl
%matplotlib inline
plt.style.use('seaborn')
mpl.rcParams['figure.figsize'] = 8, 6

我尝试降级到 matplotlib 3.1.0,因为我读到关于 seaborn 的 3.1.1 出了点问题,但无论如何结果是一样的(如果我将样式更改为 ggplot)。

Matplotlib 不会在每个刻度处都放置一个标签(以防止刻度重叠,以防它们更长)。您可以使用 ax.set_xticks(range(10)).

在每一列强制刻度

这里是一些示例代码,调用适应了 matplotlib 的 "object oriented" 接口。此外,一些额外的填充可以防止标题与顶部刻度标签跳动。请注意,标签可以是数字,matplotlib 会自动将它们解释为相应的字符串。 ax.tick_params() 可以帮助删除底部和顶部的刻度线(或者,也可以将它们左 and/or 右)。示例代码还在次要 xticks 上使用网格来进行分隔。

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import numpy as np

cm = np.random.randint(0, 25000, (10, 10)) * np.random.randint(0, 2, (10, 10))
labels = np.array([5, 6, 7, 8, 9, 10, 11, 12, 14, 15])

fig, ax = plt.subplots()
mat = ax.matshow(cm, interpolation='nearest')
mat.set_clim(0, 100)
ax.set_xticks(range(10))
ax.set_yticks(range(10))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
ax.tick_params(axis='x', which='both', bottom=False, top=False)

ax.grid(b=False, which='major', axis='both')
ax.xaxis.set_minor_locator(MultipleLocator(0.5))
ax.yaxis.set_minor_locator(MultipleLocator(0.5))
ax.grid(b=True, which='minor', axis='both', lw=2, color='white')

ax.set_title('Confusion matrix of the {} classifier'.format('clf_name'), pad=20)
plt.colorbar(mat, extend='both')
plt.show()