How does the roc_curve() function calculate FPR and TPR values behind the scenes? In my case, I got (53,) from (400,)-dimensional input data

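I am plotting a ROC curve. I have a neural network classifier with one hidden layer, so my output is the result of the last layer's activation function, which I call A2; this is the probability input to roc_curve(). My A2 and predictions have the following shapes and values: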

print(A2.ravel().shape)
print(predictions.ravel().shape)
print(A2, predictions)

Output:

(400,)
(400,)
[[3.22246780e-04 7.64373268e-01 7.64385217e-01 7.64372464e-01
  1.63920340e-01 7.64372463e-01 2.75254103e-04 7.65185909e-01
  2.06186064e-01 2.12094433e-01 2.75251983e-04 7.64372463e-01
  2.11985152e-01 2.10202927e-01 2.75252955e-04 9.44088883e-02
  2.02522498e-01 2.07370306e-01 2.50282683e-03 2.75260253e-04
  2.11928461e-01 2.75251291e-04 2.75251291e-04 2.75251498e-04
  2.75251306e-04 1.35809613e-01 2.75464969e-04 1.74181943e-01
  2.75435676e-04 2.75251294e-04 2.96236579e-04 2.75268578e-04
  2.76053487e-04 2.78105904e-04 2.75293008e-04 2.75251307e-04
  2.87538148e-04 2.75270689e-04 2.39320951e-06 4.45134656e-02
  2.75251367e-04 2.75251506e-04 2.75251303e-04 2.31132556e-06
  3.69449012e-04 2.75251293e-04 5.59346558e-02 2.31132310e-06
  1.82980485e-01 6.20515482e-06 2.32293394e-02 1.58108674e-03
  2.75252597e-04 1.19360888e-02 2.27051743e-01 2.31161383e-06
  2.31132421e-06 2.31132310e-06 2.31573234e-06 2.31132310e-06
  5.15530179e-01 2.31132310e-06 2.31132311e-06 2.46803695e-06
  2.31132310e-06 2.31141693e-06 2.31132314e-06 2.31181353e-06
  1.08428788e-03 3.91750347e-01 2.15413251e-01 2.31136922e-06
  2.31132310e-06 2.31135038e-06 2.31132310e-06 4.18257225e-02
  2.31132310e-06 2.31692274e-06 2.31132315e-06 2.34152146e-06
  2.31132310e-06 2.31132310e-06 2.31134156e-06 2.32276423e-06
  2.31184444e-06 2.31189807e-06 2.31132310e-06 3.03902587e-06
  2.33123340e-06 6.74029292e-03 1.37374673e-04 7.11777353e-06
  2.31332212e-06 2.31134309e-06 2.85446765e-01 8.45686446e-04
  2.95393201e-06 6.30729453e-02 2.35681287e-06 1.67406531e-05
  1.39482094e-04 1.47208937e-05 2.64716376e-05 1.48764918e-05
  2.37288319e-06 1.76484186e-05 1.47209077e-05 4.24952409e-05
  2.47222738e-04 1.53198138e-05 5.10281474e-06 1.47209298e-05
  1.47208667e-05 2.64277585e-01 1.47208667e-05 1.55307243e-01
  1.47208865e-05 2.91081049e-03 1.47208667e-05 1.47208667e-05
  1.47208667e-05 1.47903704e-05 1.47238820e-05 3.11567098e-02
  4.14289114e-01 1.50836911e-05 2.78303520e-02 1.47208667e-05
  1.47251817e-05 1.47947695e-05 1.47208667e-05 1.47208667e-05
  1.47208940e-05 1.48783712e-05 2.05607558e-04 1.47208667e-05
  4.83812804e-05 1.47208667e-05 2.09377734e-01 1.49642652e-05
  1.47221481e-05 1.47568362e-05 2.77831915e-01 4.82959556e-01
  4.50969045e-01 3.82364226e-02 4.11377002e-02 2.16308926e-01
  8.88141165e-02 2.12679453e-01 2.24050631e-02 1.47208667e-05
  2.12677744e-01 2.12677744e-01 2.12677760e-01 2.33568941e-01
  2.28926909e-01 2.13773365e-01 2.12678951e-01 1.35565877e-03
  2.47656669e-01 1.08727082e-01 2.12677744e-01 2.12678014e-01
  2.12677744e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01
  2.11844159e-01 1.51525672e-03 2.12677744e-01 2.12677744e-01
  7.65697761e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01
  2.12668331e-01 2.12677744e-01 2.12677744e-01 2.12677744e-01
  2.12677744e-01 2.12677744e-01 2.12677744e-01 6.01467058e-02
  2.12677744e-01 2.12677744e-01 2.12677495e-01 2.12677744e-01
  2.12677743e-01 2.12677744e-01 7.72857608e-01 2.09249431e-01
  7.86146268e-01 7.64683696e-01 8.39288704e-01 2.12677744e-01
  8.05987357e-01 7.73524718e-01 7.64722596e-01 7.64646794e-01
  2.12677744e-01 8.54868081e-01 7.66923142e-01 8.54244158e-01
  2.11261708e-01 7.66992993e-01 2.12677744e-01 2.12598362e-01
  7.66165847e-01 9.99643109e-01 7.65268010e-01 9.99685903e-01
  9.99685903e-01 7.65043689e-01 2.12677744e-01 2.12677744e-01
  7.64840536e-01 9.99685901e-01 9.99332786e-01 2.12677743e-01
  7.79121852e-01 9.99685785e-01 7.79074180e-01 7.65194741e-01
  8.98667738e-01 9.99684795e-01 9.58419683e-01 9.99685902e-01
  9.99685882e-01 9.99639779e-01 9.99639274e-01 9.99677983e-01
  9.99685736e-01 9.99685902e-01 9.92940564e-01 9.99685903e-01
  9.99685839e-01 8.30995491e-01 9.90611316e-01 9.99997341e-01
  9.99670704e-01 9.23825584e-01 9.99685666e-01 9.99996824e-01
  9.99685902e-01 9.40290068e-01 9.99685903e-01 9.99996965e-01
  9.99685364e-01 9.99997362e-01 9.99685801e-01 9.99997362e-01
  9.99996900e-01 9.99685513e-01 9.99997362e-01 9.99684995e-01
  9.99676405e-01 9.99997362e-01 6.89410113e-01 5.28997119e-01
  9.93019339e-01 6.62017810e-01 9.99997362e-01 9.99997362e-01
  9.99997358e-01 9.99997362e-01 9.99997362e-01 9.99997346e-01
  9.99997362e-01 9.99997352e-01 9.99997362e-01 9.99997362e-01
  9.99997362e-01 9.99071790e-01 9.99997362e-01 9.99997362e-01
  3.46195433e-01 9.99995537e-01 9.99997362e-01 9.99997362e-01
  9.99997362e-01 9.99997362e-01 9.99997362e-01 9.99996894e-01
  7.67197871e-01 9.99997179e-01 1.65047845e-01 9.99978488e-01
  2.93981729e-01 9.99997362e-01 9.99997361e-01 9.29067186e-01
  9.99997362e-01 9.48399940e-01 9.99997362e-01 9.99997362e-01
  6.78299886e-01 9.99997362e-01 9.99997362e-01 9.63677152e-01
  9.99997362e-01 3.67733752e-01 9.99997222e-01 7.74993071e-01
  6.37972260e-01 9.99943783e-01 9.77268446e-01 9.99976242e-01
  7.00255679e-01 9.99983200e-01 9.99983201e-01 9.99983138e-01
  9.99983197e-01 9.86360906e-01 9.99983201e-01 9.99389801e-01
  9.98380059e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01
  9.99983199e-01 9.99983199e-01 9.99983201e-01 9.99391768e-01
  9.99983201e-01 9.99983201e-01 9.99981131e-01 9.99983201e-01
  9.99983201e-01 9.76520592e-01 8.44076103e-01 9.99983201e-01
  9.99983201e-01 9.99983201e-01 9.99983201e-01 9.99983201e-01
  9.99899640e-01 9.99983201e-01 9.99983193e-01 9.99964112e-01
  9.99983201e-01 9.99983201e-01 9.99983201e-01 7.58322592e-01
  9.99983201e-01 9.99983201e-01 9.99981971e-01 7.64372463e-01
  7.64372463e-01 9.99983201e-01 9.06823611e-01 9.99983201e-01
  7.64372463e-01 9.99983201e-01 2.01516877e-01 7.64372463e-01
  3.98768426e-01 7.64372463e-01 9.81611504e-01 7.64372463e-01
  7.64372463e-01 7.64370725e-01 7.64372463e-01 7.64372463e-01
  9.99979567e-01 1.90105310e-01 7.64372463e-01 7.64372463e-01
  4.09226724e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01
  7.64372463e-01 7.64387743e-01 7.64372463e-01 7.64372463e-01
  7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01
  7.64372463e-01 7.64372463e-01 7.64372463e-01 7.64372463e-01
  7.76876797e-01 7.64372463e-01 2.07693046e-01 7.64372463e-01
  7.64372463e-01 7.59770748e-01 7.64372463e-01 7.64372463e-01
  7.66343703e-01 2.05588421e-01 7.64372828e-01 2.06636497e-01
  1.97645490e-01 2.09816835e-01 7.64372464e-01 1.77842165e-01]] [[0 1 1 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0
  0 0 0 0 0 0 0 0 0 0 1 0 1 1 1 0 1 1 1 1 0 1 1 1 0 1 0 0 1 1 1 1 1 1 0 0
  1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 0 1
  0 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 0 1 1 1
  1 1 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 0 1 0
  0 0 1 0]]

Now, when I pass these values to roc_curve(), I get fpr, tpr, and thresholds with the following shapes and values:

fpr, tpr, threshold = roc_curve(Y.ravel(), A2.ravel())
print(fpr.shape, tpr.shape, threshold.shape)
print(fpr, tpr, threshold)

Output:

(53,) (53,) (53,)
[0.    0.    0.    0.    0.    0.    0.    0.    0.005 0.005 0.015 0.015
 0.025 0.025 0.03  0.03  0.035 0.035 0.05  0.05  0.06  0.06  0.065 0.065
 0.075 0.075 0.095 0.095 0.1   0.1   0.19  0.27  0.285 0.285 0.3   0.3
 0.32  0.32  0.325 0.325 0.335 0.335 0.34  0.34  0.345 0.345 0.35  0.35
 0.355 0.355 0.36  0.36  1.   ] [0.    0.015 0.045 0.19  0.23  0.235 0.245 0.615 0.615 0.62  0.62  0.64
 0.64  0.665 0.665 0.675 0.675 0.685 0.685 0.69  0.69  0.7   0.7   0.715
 0.825 0.895 0.895 0.905 0.905 0.92  0.92  0.935 0.935 0.945 0.945 0.95
 0.95  0.955 0.955 0.96  0.96  0.965 0.965 0.97  0.97  0.975 0.975 0.99
 0.99  0.995 0.995 1.    1.   ] [1.99999736e+00 9.99997362e-01 9.99997362e-01 9.99995537e-01
 9.99983201e-01 9.99983201e-01 9.99983201e-01 8.44076103e-01
 8.39288704e-01 8.30995491e-01 7.86146268e-01 7.74993071e-01
 7.72857608e-01 7.66165847e-01 7.65697761e-01 7.65194741e-01
 7.65185909e-01 7.64840536e-01 7.64646794e-01 7.64387743e-01
 7.64373268e-01 7.64372464e-01 7.64372464e-01 7.64372463e-01
 7.64372463e-01 5.28997119e-01 4.14289114e-01 3.98768426e-01
 3.91750347e-01 2.93981729e-01 2.12677744e-01 2.12677744e-01
 2.12677744e-01 2.12677743e-01 2.12668331e-01 2.12598362e-01
 2.11844159e-01 2.11261708e-01 2.10202927e-01 2.09816835e-01
 2.09249431e-01 2.07693046e-01 2.07370306e-01 2.06636497e-01
 2.06186064e-01 2.05588421e-01 2.02522498e-01 1.90105310e-01
 1.82980485e-01 1.77842165e-01 1.74181943e-01 1.65047845e-01
 2.31132310e-06]

So my ROC curve looks like this:

import matplotlib.pyplot as plt

plt.figure()
plt.plot(fpr, tpr)  # one point per threshold returned by roc_curve()
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic (ROC curve)')
plt.show()

Output: (ROC curve plot, image not included)

Why do I get FPR, TPR, and thresholds of shape (53,)? My case is just a simple binary classification. Thanks for your help.

The number of thresholds is determined as follows:

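  • Step 1: keep only the distinct score values, then add one extra threshold at the front so that the curve starts at (0, 0). In the output above that extra threshold is 1.99999736e+00, i.e. max(y_score) + 1.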

Source:

# y_score typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = np.where(np.diff(y_score))[0]
threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

  • Step 2 (if drop_intermediate and len(fps) > 2): drop the thresholds that correspond to intermediate points lying on a straight line between other points.

Source:

# Attempt to drop thresholds corresponding to points in between and
# collinear with other points. These are always suboptimal and do not
# appear on a plotted ROC curve (and thus do not affect the AUC).
# Here np.diff(_, 2) is used as a "second derivative" to tell if there
# is a corner at the point. Both fps and tps must be tested to handle
# thresholds with multiple data points (which are combined in
# _binary_clf_curve). This keeps all cases where the point should be kept,
# but does not drop more complicated cases like fps = [1, 3, 7],
# tps = [1, 2, 4]; there is no harm in keeping too many thresholds.
if drop_intermediate and len(fps) > 2:
    optimal_idxs = np.where(np.r_[True,
                                  np.logical_or(np.diff(fps, 2),
                                                np.diff(tps, 2)),
                                  True])[0]
    fps = fps[optimal_idxs]
    tps = tps[optimal_idxs]
    thresholds = thresholds[optimal_idxs]
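
As a rough illustration of the two steps above, here is a minimal sketch on made-up toy data (the arrays below are hypothetical and are not the A2 from the question). It assumes a reasonably recent scikit-learn, where roc_curve() always prepends one extra threshold. With drop_intermediate=False the number of thresholds should equal the number of distinct scores plus one; with the default drop_intermediate=True it can only be the same or smaller.

```python
import numpy as np
from sklearn.metrics import roc_curve

# Hypothetical toy data: 8 samples, with some tied scores.
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0])
y_score = np.array([0.9, 0.8, 0.7, 0.7, 0.4, 0.3, 0.2, 0.2])

# Step 1: only distinct score values can act as thresholds.
n_distinct = np.unique(y_score).size

# Keep every distinct threshold (plus the extra one at the front).
fpr_all, tpr_all, thr_all = roc_curve(y_true, y_score, drop_intermediate=False)

# Step 2: the default also drops points collinear with their neighbours.
fpr, tpr, thr = roc_curve(y_true, y_score)

print("samples:            ", y_score.size)
print("distinct scores:    ", n_distinct)
print("thresholds (all):   ", thr_all.size)
print("thresholds (pruned):", thr.size)
```

The same reasoning applies to the question: 400 scores with many ties and many collinear points can shrink to a much shorter array, and passing drop_intermediate=False shows how much of the reduction comes from Step 2 alone.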

The FPR and TPR are then computed for each remaining threshold.
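
To make that last step concrete, the sketch below recomputes FPR and TPR by hand for every threshold returned by roc_curve() and checks them against the library's output. The toy arrays and the helper rates_at() are hypothetical names introduced only for this illustration; the one assumption about scikit-learn is that a sample counts as predicted positive when its score is greater than or equal to the threshold.

```python
import numpy as np
from sklearn.metrics import roc_curve

# Hypothetical toy data, not the A2 from the question.
y_true = np.array([0, 0, 1, 0, 1, 1, 0, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.9, 0.35, 0.7])

fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)

n_pos = np.sum(y_true == 1)
n_neg = np.sum(y_true == 0)

def rates_at(threshold):
    """Return (FPR, TPR) when scores >= threshold are predicted positive."""
    pred = y_score >= threshold
    tp = np.sum(pred & (y_true == 1))
    fp = np.sum(pred & (y_true == 0))
    return fp / n_neg, tp / n_pos

# One (FPR, TPR) pair per threshold, in the order roc_curve() returned them.
manual = np.array([rates_at(t) for t in thresholds])

print(np.allclose(manual[:, 0], fpr), np.allclose(manual[:, 1], tpr))
```

The first threshold lies above every score, so nothing is predicted positive there and the curve starts at (0, 0).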