CNN:为整个测试数据集生成混淆矩阵

CNN: Generate a confusion matrix for entire test dataset

我正在使用以下代码来预测我的模型在数据集上的输出。

correct = 0
total_predictions = []
actual_labels = []
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = model(images)
  
        _, predicted = torch.max(outputs.data, 1)
        actual_labels.append(labels)
        total_predictions.append(final_pred)
        final_pred = torch.FloatTensor(final_pred).to(device)
        correct += (predicted == labels).sum().item()

现在为了生成整个数据集的混淆矩阵,我尝试将我的预测和测试标签存储在列表中并将其传递给 sklearn 中的 confusion_matrix,但失败并出现以下错误:

ValueError: You appear to be using a legacy multi-label data representation. Sequence of sequences are no longer supported; use a binary array or sparse matrix instead.

谁能帮我计算整个数据集的混淆矩阵?

以下代码只计算最后一批:

 cf = confusion_matrix(predicted.cpu(), labels.cpu())

Update-1

使用@CutePoison 的模板,我明白了。

您似乎在使用旧版多标签数据表示法。不再支持序列序列;改为使用二进制数组或稀疏矩阵 - MultiLabelBinarizer 转换器可以转换为这种格式。

labels={}
labels['healthy_wheat'] = 0
labels['leaf_rust'] = 1
labels['stem_rust'] = 2

def conf_mat(y_true,y_pred,columns,**kwargs):
    conf_mat = confusion_matrix(y_true,y_pred,labels = columns,**kwargs)
    df = pd.DataFrame(conf_mat,columns = columns, index = columns)
    df.columns.name="pred"
    df.index.name="true"
    return df

conf_mat(actual_labels,total_predictions ,columns =labels,normalize="true")

我使用此代码段创建混淆矩阵,适用于多个 classes

from sklearn.metrics import confusion_matrix

def conf_mat(y_true,y_pred,columns,**kwargs):
    """
    Creates a "pretty" confusion matrix
    """

    conf_mat = confusion_matrix(y_true,y_pred,labels = columns,**kwargs)
    df = pd.DataFrame(conf_mat,columns = columns, index = columns)
    df.columns.name="pred"
    df.index.name="true"
    return df



conf_mat(actual_labels,final_pred ,columns =np.unique(actual_labels),normalize="true")

请注意,您可能希望根据标签的创建方式更改 columns

此外,您的 final_pred 必须包含您的 class 预测而不是分数,即 final_pred = [0,1,2,0...] 而不是 final_pred= [[0.8,0.1,0.1], [0.1,0.7,0.2],[0.05,0.05,0.9],[0.75,0.2,0.05],...]