混淆矩阵原始数据不匹配

Confusion matrix raws are mismatched

我创建了一个可以正常工作的混淆矩阵,但它的原始数据似乎没有按应有的方式与标签相关联。

我有一些字符串列表,分为训练和测试部分:

 train + test:
 positive: 16 + 4 = 20
 negprivate:  53 + 14 = 67
 negstratified: 893 + 224 = 1117

混淆矩阵建立在测试数据上:

 [[  0  14   0]
 [  3 220   1]
 [  0   4   0]]

代码如下:

my_tags = ['negprivate', 'negstratified', 'positive']

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
    logging.info('plot_confusion_matrix')
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(my_tags))
    target_names = my_tags
    plt.xticks(tick_marks, target_names, rotation=45)
    plt.yticks(tick_marks, target_names)
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label') 
    plt.show()

def evaluate_prediction(target, predictions, taglist, title="Confusion matrix"):
    logging.info('Evaluate prediction')
    print('accuracy %s' % accuracy_score(target, predictions))
    cm = confusion_matrix(target, predictions)
    print('confusion matrix\n %s' % cm)
    print('(row=expected, col=predicted)')
    print 'rows: \n %s \n %s \n %s ' % (taglist[0], taglist[1], taglist[2])

    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plot_confusion_matrix(cm_normalized, title + ' Normalized')

...

test_targets, test_regressors = zip(
    *[(doc.tags[0], doc2vec_model.infer_vector(doc.words, steps=20)) for doc in alltest]) 
logreg = linear_model.LogisticRegression(n_jobs=1, C=1e5)
logreg = logreg.fit(train_regressors, train_targets)
evaluate_prediction(test_targets, logreg.predict(test_regressors), my_tags, title=str(doc2vec_model))

但关键是我实际上必须查看结果矩阵中的数字并更改 my_tags 的顺序,以便它们可以彼此一致。据我所知,这应该以某种自动方式进行。 我想知道在哪?

我认为这只是标签的排序顺序,即 np.unique(target) 的输出。

总是最好有整数 class 标签,一切似乎 运行 更顺畅一些。您可以使用 LabelEncoder 获取这些,即

from sklearn import preprocessing
my_tags = ['negprivate', 'negstratified', 'positive']
le = preprocessing.LabelEncoder()
new_tags = le.fit_transform(my_tags)

现在您将拥有 [0 1 2] 作为您的新标签。当您进行绘图时,您希望标签直观,​​因此您可以使用 inverse_transform 来获取标签,即

le.inverse_transform(0)

输出:

'negprivate'