Find best confusion matrix from array of confusion matrices

Suppose I have an array of NumPy confusion matrices obtained from k-fold cross-validation:

[array([[39,  4],
       [ 9,  6]], dtype=int64), array([[39,  4],
       [ 9,  5]], dtype=int64), array([[37,  6],
       [11,  3]], dtype=int64), array([[42,  1],
       [11,  3]], dtype=int64), array([[40,  3],
       [ 9,  5]], dtype=int64)]

I can find the mean of these confusion matrices like this:

mean_conf_matrices = np.mean(conf_matrices_arr, axis=0)

# [[39.4  3.6]
#  [ 9.8  4.4]]

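But I would like to know how to get the best confusion matrix from the array of confusion matrices, similar to how you get best_score_ from GridSearchCV. My idea is to extract TN, TP, FN, and FP and evaluate each one to find the confusion matrix with the highest TN and TP and the lowest FN and FP. Is there a more intuitive way to achieve this?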

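Edit: In more detail, my approach is to take each component of the confusion matrices (TN, TP, FN, FP) and store it in its own array, indexed by matrix. That is, all the TNs are collected from all the confusion matrices, so TN[0] is the TN taken from confusion_matrix[0], and likewise for the other components. We can then choose which count to focus on reducing: FP or FN. Say we want to reduce FN: the lowest FN is taken from the FN array, its index is looked up, and the matrix at that same index in the confusion matrix array is selected as the best one. For example, if FN[4] is the best, confusion_matrix[4] is chosen. I would like to know whether there is a more intuitive way to achieve this, because my approach feels cumbersome.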
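The selection described above can be sketched with NumPy indexing. This is only a minimal illustration, assuming each matrix is a 2x2 binary confusion matrix laid out as [[TN, FP], [FN, TP]] (the scikit-learn convention); the names stacked, idx_fn, idx_fp, best_by_fn and best_by_fp are made up for this example.

```python
import numpy as np

# The five fold-wise confusion matrices from the question.
conf_matrices_arr = [
    np.array([[39, 4], [9, 6]]),
    np.array([[39, 4], [9, 5]]),
    np.array([[37, 6], [11, 3]]),
    np.array([[42, 1], [11, 3]]),
    np.array([[40, 3], [9, 5]]),
]

# Stack into shape (n_folds, 2, 2).
# Assumed layout of each matrix: [[TN, FP], [FN, TP]].
stacked = np.stack(conf_matrices_arr)
TN, FP = stacked[:, 0, 0], stacked[:, 0, 1]
FN, TP = stacked[:, 1, 0], stacked[:, 1, 1]

# np.argmin returns the index of the first minimum it encounters.
idx_fn = int(np.argmin(FN))  # fold with the fewest false negatives
idx_fp = int(np.argmin(FP))  # fold with the fewest false positives

best_by_fn = conf_matrices_arr[idx_fn]
best_by_fp = conf_matrices_arr[idx_fp]
```

Note that np.argmin keeps the first index when several folds share the lowest count, so a secondary criterion would be needed if ties should be broken by the other error type.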

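You need to choose a classification metric that lets you compare the different classifiers: for example AUC (area under the ROC curve), precision, recall, or F1 (which combines precision and recall).
See the link in scikit-learn for the different possibilities and how they are implemented.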
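As a rough sketch of this idea, the snippet below scores each matrix with F1 and keeps the highest-scoring one. F1 is used here only because it can be computed from the four counts alone; the helper f1_from_confusion_matrix is a hypothetical name, and the [[TN, FP], [FN, TP]] layout is an assumption about how the matrices were produced.

```python
import numpy as np

# The five fold-wise confusion matrices from the question.
conf_matrices_arr = [
    np.array([[39, 4], [9, 6]]),
    np.array([[39, 4], [9, 5]]),
    np.array([[37, 6], [11, 3]]),
    np.array([[42, 1], [11, 3]]),
    np.array([[40, 3], [9, 5]]),
]


def f1_from_confusion_matrix(cm):
    """F1 of the positive class for a 2x2 matrix laid out as [[TN, FP], [FN, TP]]."""
    tn, fp, fn, tp = np.asarray(cm).ravel()
    denom = 2 * tp + fp + fn
    # F1 = 2*TP / (2*TP + FP + FN); defined as 0.0 when there are no positives at all.
    return float(2 * tp / denom) if denom else 0.0


f1_scores = [f1_from_confusion_matrix(cm) for cm in conf_matrices_arr]

# Index of the matrix with the highest F1 (first one wins on ties).
best_idx = int(np.argmax(f1_scores))
best_cm = conf_matrices_arr[best_idx]
```

Swapping in precision or recall only changes the scoring function; the argmax selection stays the same.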

I have written my implementation below. If you have a more intuitive way to achieve the same objective, feel free to post an answer.

import numpy as np


def get_tp_tn_fp_fn_from_confusion_matrix(cfmatrix):

    # Binary (2x2) confusion matrix laid out as [[TN, FP], [FN, TP]].
    TN, FP, FN, TP = (int(v) for v in np.asarray(cfmatrix).ravel())

    return (TP, TN, FP, FN)


def get_best_confusion_matrix(cfmatrices, reduction_bias=None, debug=False):

    assert reduction_bias in ['FP', 'FN'], \
        f'{reduction_bias} is not a valid reduction bias. Select "FN" or "FP".'

    if len(cfmatrices) == 0:
        raise ValueError('cfmatrices must contain at least one confusion matrix.')

    best_idx = None
    best_key = None

    for idx, cfmatrix in enumerate(cfmatrices):

        _, _, FP, FN = get_tp_tn_fp_fn_from_confusion_matrix(cfmatrix)

        # The first entry is the count to minimise; the second breaks ties.
        key = (FN, FP) if reduction_bias == 'FN' else (FP, FN)

        # Tuples compare element by element, so a tie on the primary count
        # is decided by the secondary count. Exact duplicates keep the
        # earlier matrix.
        if best_key is None or key < best_key:
            if debug: print(f'Confusion matrix {idx} is the best so far (FP={FP}, FN={FN}).')
            best_idx, best_key = idx, key
        elif debug:
            print(f'Confusion matrix {idx} is not better than confusion matrix {best_idx}.')

    if debug: print('The chosen confusion matrix is:\n', cfmatrices[best_idx])

    return cfmatrices[best_idx]