ROC curve didn't even show on the plot

After running predictions with logistic regression, this is the confusion matrix I got:

True Positives: 3
False Positives: 1309
True Negatives: 12361
False Negatives: 4

Here is the roc_auc_score:

roc_auc_score(y_test, log_preds)
0.6664071480823492
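
For reference, hard 0/1 predictions like log_preds pin the ROC down to a single (FPR, TPR) operating point, so roc_auc_score reduces to the two-point trapezoid area (TPR + 1 − FPR) / 2. The counts above reproduce the printed score exactly:

tp, fn, fp, tn = 3, 4, 1309, 12361   # counts from the confusion matrix above
tpr = tp / (tp + fn)                 # 0.42857142857142855
fpr = fp / (fp + tn)                 # 0.09575713240673006
print((tpr + (1 - fpr)) / 2)         # 0.6664071480823492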

So I tried to visualize it with this code:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

probas = lg.predict_proba(X_test)[:, 1]
def get_preds(threshold, probabilities):
    return [1 if prob > threshold else 0 for prob in probabilities]
roc_values = []
for thresh in np.linspace(0, 1, 100):
    preds = get_preds(thresh, probas)
    tn, fp, fn, tp = confusion_matrix(y_test, log_preds).ravel()
    tpr = tp/(tp+fn)
    fpr = fp/(fp+tn)
    roc_values.append([tpr, fpr])
tpr_values, fpr_values = zip(*roc_values)
fig, ax = plt.subplots(figsize=(10,7))
ax.plot(fpr_values, tpr_values)
ax.plot(np.linspace(0, 1, 100),
         np.linspace(0, 1, 100),
         label='baseline',
         linestyle='--')
plt.title('Receiver Operating Characteristic Curve', fontsize=18)
plt.ylabel('TPR', fontsize=16)
plt.xlabel('FPR', fontsize=16)
plt.legend(fontsize=12);

Below is the output. It only shows the baseline, which I didn't understand. (I don't have enough reputation to embed images, so feel free to edit one in. Thanks!)

(Image: output of the ROC AUC plot, showing only the baseline)

OK, now I have some idea of what is going on.

I added these lines to see what was happening:

print(tpr_values)
print(fpr_values)

Output:

(0.42857142857142855, 0.42857142857142855, ..., 0.42857142857142855)   # the same value 100 times
(0.09575713240673006, 0.09575713240673006, ..., 0.09575713240673006)   # the same value 100 times

All the values are identical, so all 100 points collapse onto a single spot and no curve ever shows.
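
Those are exactly the TPR and FPR of the original confusion matrix, repeated for every threshold. A quick check confirms that preds itself does change with the threshold, even though the plotted values never did:

print(sum(get_preds(0.1, probas)))   # many positives at a low threshold
print(sum(get_preds(0.9, probas)))   # only a few at a high threshold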

I solved my own problem, and the output looks right now.

The bug was here:

roc_values = []
for thresh in np.linspace(0, 1, 100):
    preds = get_preds(thresh, probas)
    tn, fp, fn, tp = confusion_matrix(y_test, log_preds).ravel()   # <-- bug: log_preds instead of preds, so thresh is never used
    tpr = tp/(tp+fn)
    fpr = fp/(fp+tn)
    roc_values.append([tpr, fpr])
tpr_values, fpr_values = zip(*roc_values)

After replacing log_preds with preds, it looks like this:

roc_values = []
for thresh in np.linspace(0, 1, 100):
    preds = get_preds(thresh, probas)
    tn, fp, fn, tp = confusion_matrix(y_test, preds).ravel()   # <-- fixed: preds instead of log_preds
    tpr = tp/(tp+fn)
    fpr = fp/(fp+tn)
    roc_values.append([tpr, fpr])
tpr_values, fpr_values = zip(*roc_values)
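
(As an aside, the list comprehension in get_preds can be vectorized; a minimal numpy equivalent:)

import numpy as np

def get_preds(threshold, probabilities):
    # same 0/1 labels as the list comprehension, computed in one vectorized step
    return (np.asarray(probabilities) > threshold).astype(int)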

It was frustrating, but anyway, it finally works.
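
For anyone who hits the same wall: scikit-learn can build the curve directly from the probabilities, with no manual threshold loop. A minimal sketch, assuming the same lg, X_test and y_test as above:

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

probas = lg.predict_proba(X_test)[:, 1]            # probabilities, not hard labels
fpr, tpr, thresholds = roc_curve(y_test, probas)   # one (FPR, TPR) pair per threshold
print(auc(fpr, tpr))                               # area under the full curve

fig, ax = plt.subplots(figsize=(10, 7))
ax.plot(fpr, tpr, label='logistic regression')
ax.plot([0, 1], [0, 1], label='baseline', linestyle='--')
plt.legend(fontsize=12);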