sklearn.metrics.roc_curve only shows 5 fprs, tprs, thresholds

My array has 520 elements, but metrics.roc_curve returns only a few fpr, tpr, and threshold values.

Here are some of the values in my score array:

[... 4.6719894  5.3444934  2.575739   3.5660675  3.4357991  4.195427
4.120169   5.021058   5.308503   5.3124313  4.8253884  4.7469654
5.0011086  5.170149   4.5555115  4.4109273  4.6183085  4.356304
4.413242   4.1186514  5.0573816  4.646429   5.063631   4.363433
5.431669   6.1605806  6.1510544  4.8733225  6.0209446  6.5198536
5.1457767  1.3887328  1.3165888  1.143339   1.717379   1.6670974
1.1816382  1.2497046  1.035109   1.4904765  1.195155   1.2590547
1.0998954  1.6484532  1.5722921  1.2841778  1.1058662  1.3368237
1.3262213  1.215088   1.4224783  1.046008   1.262415   1.2319984
1.2202312  1.1610713  1.2327379  1.1951761  1.8699458  0.98760885
1.6670336  1.5051543  1.2339936  1.5215651  1.534271   1.1805111
1.1587876  1.0894692  1.1936147  1.3278677  1.2409594  1.0499009... ]

These are the only results I get:

fpr [0.         0.         0.         0.00204499 0.00204499 1.        ] 
tpr [0.         0.03225806 0.96774194 0.96774194 1.         1.        ] 
threshold [7.5198536 6.5198536 3.4357991 2.5991373 2.575739  0.8769072]

What causes this?

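This is most likely caused by the default value of the drop_intermediate parameter of roc_curve(), which is True and drops suboptimal thresholds (see the scikit-learn documentation). The dropped thresholds would not appear on a plotted ROC curve: they lie on straight segments between the points that are kept, so removing them changes neither the shape of the curve nor its AUC. You can prevent this behavior by passing drop_intermediate=False.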
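To make the selection rule concrete, here is a minimal plain-Python sketch. roc_points is a hypothetical helper, not part of scikit-learn; it follows the approach of recent scikit-learn releases as far as I understand it: one candidate point per distinct score, interior points dropped when the step into them equals the step out of them, and a leading (0, 0) point whose threshold is inf. Older releases used the maximum score + 1 for that leading threshold, which matches the 7.5198536 = 6.5198536 + 1 in the question.

```python
def roc_points(y_true, y_score, drop_intermediate=True):
    """Hypothetical helper (not scikit-learn's API) that mimics how roc_curve
    picks its (fpr, tpr, threshold) points."""
    # Rank samples by decreasing score; Python's sort is stable.
    ranked = sorted(zip(y_score, y_true), key=lambda pair: pair[0], reverse=True)
    fps, tps, thresholds = [], [], []
    fp = tp = 0
    for i, (score, label) in enumerate(ranked):
        if label:
            tp += 1
        else:
            fp += 1
        # One candidate point per distinct score: emit it after the last tied sample.
        if i == len(ranked) - 1 or ranked[i + 1][0] != score:
            fps.append(fp)
            tps.append(tp)
            thresholds.append(score)

    if drop_intermediate and len(fps) > 2:
        keep = [0]
        for i in range(1, len(fps) - 1):
            # An interior point is redundant when the step into it equals the step
            # out of it: it sits in the middle of a straight run of the curve.
            same_fp_step = fps[i] - fps[i - 1] == fps[i + 1] - fps[i]
            same_tp_step = tps[i] - tps[i - 1] == tps[i + 1] - tps[i]
            if not (same_fp_step and same_tp_step):
                keep.append(i)
        keep.append(len(fps) - 1)
        fps = [fps[i] for i in keep]
        tps = [tps[i] for i in keep]
        thresholds = [thresholds[i] for i in keep]

    # Prepend the (0, 0) point with a threshold above every score.
    fps, tps = [0] + fps, [0] + tps
    thresholds = [float("inf")] + thresholds
    fpr = [f / fps[-1] for f in fps]
    tpr = [t / tps[-1] for t in tps]
    return fpr, tpr, thresholds


def auc_trapezoid(x, y):
    """Area under a piecewise-linear curve via the trapezoidal rule."""
    return sum((x[i + 1] - x[i]) * (y[i] + y[i + 1]) / 2 for i in range(len(x) - 1))


# Toy data: 4 positives and 4 negatives, with one negative ranked above one positive.
labels = [1, 1, 1, 0, 1, 0, 0, 0]
scores = [8, 7, 6, 5, 4, 3, 2, 1]

fpr, tpr, thr = roc_points(labels, scores)
fpr_full, tpr_full, thr_full = roc_points(labels, scores, drop_intermediate=False)
print(len(fpr), len(fpr_full))  # 6 9
print(auc_trapezoid(fpr, tpr), auc_trapezoid(fpr_full, tpr_full))  # 0.9375 0.9375
```

The reduced toy curve has the same shape as the one in the question, fpr = [0, 0, 0, 0.25, 0.25, 1], while drop_intermediate=False keeps all nine points (eight distinct scores plus the leading point). Both versions enclose the same area.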

Here is an example:

import numpy as np
from sklearn.datasets import fetch_openml

# fetch_mldata has been removed from scikit-learn; fetch_openml replaces it.
# as_frame=False returns NumPy arrays, so the positional indexing below works.
mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
mnist["target"] = mnist["target"].astype(np.int8)

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_predict

X, y = mnist["data"], mnist["target"]
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

sgd_clf = SGDClassifier(random_state=42, verbose=0)
sgd_clf.fit(X_train, y_train_5)

# Out-of-fold decision_function scores, one per training sample (3-fold CV)
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method='decision_function')

# ROC Curves

from sklearn.metrics import roc_curve

fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)

len(thresholds), len(fpr), len(tpr)
# (3472, 3472, 3472)

# Unlike the precision/recall curve in the scikit-learn version used here, the length of
# thresholds (and of fpr/tpr) returned by roc_curve depends on drop_intermediate, which
# drops suboptimal thresholds that would not appear on a plotted ROC curve

fpr_, tpr_, thrs = roc_curve(y_train_5, y_scores, drop_intermediate=False)
len(fpr_), len(tpr_), len(thrs)
# (60001, 60001, 60001): one point per distinct score (60000) plus the leading point
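The same reasoning explains the six points in the question. The step sizes 0.03225806 ≈ 1/31 and 0.00204499 ≈ 1/489 suggest 31 positives and 489 negatives (520 in total, matching the array length). That is an inference from the printed values, not something stated in the question. Under that reading, 30 positives are ranked above every negative, one negative comes next, then the last positive, then the remaining negatives. Almost every candidate threshold therefore lies on one of two straight runs (up the fpr = 0 axis and across tpr = 1), and drop_intermediate keeps only the corners. A quick trapezoid-rule check on the copied values gives an AUC of about 0.99993:

```python
# fpr and tpr copied from the question's output (drop_intermediate=True)
fpr = [0.0, 0.0, 0.0, 0.00204499, 0.00204499, 1.0]
tpr = [0.0, 0.03225806, 0.96774194, 0.96774194, 1.0, 1.0]


def auc_trapezoid(x, y):
    """Area under a piecewise-linear curve via the trapezoidal rule."""
    return sum((x[i + 1] - x[i]) * (y[i] + y[i + 1]) / 2 for i in range(len(x) - 1))


area = auc_trapezoid(fpr, tpr)
# Assuming 31 positives and 489 negatives (inferred from the step sizes),
# exactly one negative outranks one positive, so AUC = 1 - 1 / (31 * 489).
expected = 1 - 1 / (31 * 489)
print(round(area, 5), round(expected, 5))  # 0.99993 0.99993
```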