混淆矩阵无法显示所有标签
Confusion matrix fails to show all labels
我已经为多标签 class 化创建了一个 class 化矩阵来评估 MLPClassifier 模型的性能。混淆矩阵输出应该是 10x10,但有时我得到 8x8,因为它不显示 1 或 2 class 标签的标签值,正如您在我 [=18= 时从代码下方的混淆矩阵热图中看到的那样] 整个Jupyter笔记本。真实标签和预测标签的 class 标签是从 1 到 10(无序)。是因为代码错误还是仅取决于测试数据集在将数据拆分为训练集和测试集时接受的随机输入样本?我应该如何解决这个问题?代码的实现如下所示:
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
print(cm)
str(cm)
Out: [[20 0 0 1 0 5 1 0]
[ 3 0 0 0 0 0 0 0]
[ 1 1 0 1 0 1 0 0]
[ 3 0 0 0 0 3 1 1]
[ 0 0 0 0 0 1 0 0]
[ 3 0 0 1 0 2 1 1]
[ 3 0 0 0 0 0 0 2]
[ 1 0 0 0 0 0 0 1]]
'[[20 0 0 1 0 5 1 0]\n [ 3 0 0 0 0 0 0 0]\n [ 1 1 0 1 0
1 0 0]\n [ 3 0 0 0 0 3 1 1]\n [ 0 0 0 0 0 1 0 0]\n [ 3 0
0 1 0 2 1 1]\n [ 3 0 0 0 0 0 0 2]\n [ 1 0 0 0 0 0 0
1]]'
import matplotlib.pyplot as plt
import seaborn as sns
side_bar = [1,2,3,4,5,6,7,8,9,10]
f, ax = plt.subplots(figsize=(12,12))
sns.heatmap(cm, annot=True, linewidth=.5, linecolor="r", fmt=".0f", ax = ax)
ax.set_xticklabels(side_bar)
ax.set_yticklabels(side_bar)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()
confusion matrix heatmap
我觉得这里有点混乱!混淆矩阵是 set(y_test) + set(y_pred)。所以如果它变成 8。那么矩阵将是 8 X 8。现在,如果你有更多的标签你想打印,即使它是否全为零也没关系。然后在构建矩阵时,您需要输入“labels”参数。
y_true = [2, 0, 2, 2, 0, 1,5]
y_pred = [0, 0, 2, 2, 0, 2,4]
confusion_matrix = confusion_matrix(y_true, y_pred,labels=[0,1,2,3,4,5,6])
如您所见,y_true 或 y_pred 中确实没有 6,您将为它清零。
我已经为多标签 class 化创建了一个 class 化矩阵来评估 MLPClassifier 模型的性能。混淆矩阵输出应该是 10x10,但有时我得到 8x8,因为它不显示 1 或 2 class 标签的标签值,正如您在我 [=18= 时从代码下方的混淆矩阵热图中看到的那样] 整个Jupyter笔记本。真实标签和预测标签的 class 标签是从 1 到 10(无序)。是因为代码错误还是仅取决于测试数据集在将数据拆分为训练集和测试集时接受的随机输入样本?我应该如何解决这个问题?代码的实现如下所示:
from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test, y_pred)
print(cm)
str(cm)
Out: [[20 0 0 1 0 5 1 0]
[ 3 0 0 0 0 0 0 0]
[ 1 1 0 1 0 1 0 0]
[ 3 0 0 0 0 3 1 1]
[ 0 0 0 0 0 1 0 0]
[ 3 0 0 1 0 2 1 1]
[ 3 0 0 0 0 0 0 2]
[ 1 0 0 0 0 0 0 1]]
'[[20 0 0 1 0 5 1 0]\n [ 3 0 0 0 0 0 0 0]\n [ 1 1 0 1 0
1 0 0]\n [ 3 0 0 0 0 3 1 1]\n [ 0 0 0 0 0 1 0 0]\n [ 3 0
0 1 0 2 1 1]\n [ 3 0 0 0 0 0 0 2]\n [ 1 0 0 0 0 0 0
1]]'
import matplotlib.pyplot as plt
import seaborn as sns
side_bar = [1,2,3,4,5,6,7,8,9,10]
f, ax = plt.subplots(figsize=(12,12))
sns.heatmap(cm, annot=True, linewidth=.5, linecolor="r", fmt=".0f", ax = ax)
ax.set_xticklabels(side_bar)
ax.set_yticklabels(side_bar)
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.show()
confusion matrix heatmap
我觉得这里有点混乱!混淆矩阵是 set(y_test) + set(y_pred)。所以如果它变成 8。那么矩阵将是 8 X 8。现在,如果你有更多的标签你想打印,即使它是否全为零也没关系。然后在构建矩阵时,您需要输入“labels”参数。
y_true = [2, 0, 2, 2, 0, 1,5]
y_pred = [0, 0, 2, 2, 0, 2,4]
confusion_matrix = confusion_matrix(y_true, y_pred,labels=[0,1,2,3,4,5,6])
如您所见,y_true 或 y_pred 中确实没有 6,您将为它清零。