混淆矩阵与 dask

Confusion matrix with dask

我正在尝试使用 Dask 计算混淆矩阵元素。 从算法的角度来看我的实现似乎还可以。 但是,当我 运行 它在 2 个大小为 100 万的数组上时,它需要很长时间。

有人对如何优化此代码有建议吗?

def confusion_matrix_dask(truth,predictions,labels_list=[]):
    TP=0
    FP=0
    FN=0
    TN=0
    if not labels_list:
        TP=(truth[predictions==1]==1).sum()
        FP=(truth[predictions!=1]==1).sum()
        TN=(truth[predictions!=1]!=1).sum()
        FN=(truth[predictions==1]!=1).sum()
    for label in labels_list:
        TP=(truth[predictions==label]==label).sum()+TP
        FP=(truth[predictions!=label]==label).sum()+FP
        TN=(truth[predictions!=label]!=label).sum()+TN
        FN=(truth[predictions==label]!=label).sum()+FN


    return np.array([[TN.compute(), FP.compute()] , [TN.compute() ,FN.compute()]])

您应该注意的一个快速改进:

import dask
TP, FP, TN, FN = dask.compute(TP, FP, TN, FN)

而不是对每个调用 .compute()。这将共享共同的数据和任务,从而减少要完成的总工作量。