What average precision is the plot_precision_recall_curve() function plotting?
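After using scikit-learn's plot_precision_recall_curve(), I would like to know which average precision the function reports. Looking at the documentation, this is what I found for binary targets: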
# %%
# Compute the average precision score
# ...................................
from sklearn.metrics import average_precision_score
average_precision = average_precision_score(y_test, y_score)
print('Average precision-recall score: {0:0.2f}'.format(
    average_precision))
Here is my data:
import numpy as np
from sklearn import svm

clf_4 = svm.SVC()
clf_4.fit(X_train, y_train)
# Hard 0/1 class predictions, not continuous scores
y_clf_4 = clf_4.predict(X_test)
y1_test = np.array([1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1])
y1_clf4 = np.array([0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1])
average_precision_5 = average_precision_score(y1_test, y1_clf4)
average_precision_5
Out: 0.5625
Now we use plot_precision_recall_curve with the same classifier as above, where X_test looks like this:
X_test= np.array([[0.01167537, 0.04676259, 0.02145552, 0.015625 , 0. ,
0. , 0. , 0.5 , 0.01020408, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 1. , 0. ,
0. , 0. , 0. , 1. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.00478415, 0.01258993, 0.06759886, 0.09375 , 0. ,
0. , 0. , 0.43421053, 0. , 1. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.01503446, 0.04136691, 0.02600806, 0.015625 , 0. ,
0. , 1. , 0.13157895, 0.02721088, 0. ,
0. , 0. , 0. , 0. , 0. ,
1. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.017396 , 0.04856115, 0.07737383, 0.046875 , 0. ,
0. , 0. , 0.44736842, 0.04421769, 0. ,
0. , 1. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
1. , 0. , 0. , 1. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.0072882 , 0.01079137, 0.07866155, 0.078125 , 1. ,
0. , 0. , 0.63157895, 0. , 1. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.00733909, 0.0323741 , 0.0487578 , 0.046875 , 0. ,
0. , 0. , 0.44736842, 0.02040816, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 1. , 0. , 0. , 1. ,
0. , 1. , 0. , 0. , 0. ,
0. , 0. , 1. , 0. , 0. ,
0. ],
[0.02579371, 0.11151079, 0.03639438, 0.0625 , 0. ,
0. , 0. , 0.53947368, 0.02380952, 0. ,
0. , 0. , 1. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 1. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.00203581, 0.03417266, 0.12611863, 0.125 , 0. ,
0. , 0. , 0.05263158, 0.00680272, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1. ,
0. , 0. , 1. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.00527275, 0.03057554, 0.0344563 , 0.03125 , 0. ,
0. , 1. , 0.09210526, 0.00680272, 1. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.00590385, 0.02158273, 0.05135926, 0.046875 , 0. ,
0. , 0. , 0.43421053, 0.00340136, 1. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1. ,
0. , 0. , 0. , 1. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.01910608, 0.16366906, 0.05917014, 0.03125 , 1. ,
0. , 1. , 0.28947368, 0.12244898, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 1. , 0. , 0. , 1. ,
0. , 0. , 1. , 0. , 0. ,
0. , 0. , 1. , 0. , 0. ,
0. ],
[0.12737045, 0.13669065, 0.07280827, 0.078125 , 1. ,
0. , 0. , 0.46052632, 0.07823129, 0. ,
0. , 1. , 0. , 0. , 0. ,
0. , 1. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
1. , 1. , 0. , 0. , 0. ,
0. , 0. , 0. , 1. , 0. ,
0. ],
[0.0537861 , 0.17446043, 0.14109651, 0.078125 , 0. ,
0. , 0. , 0.32894737, 0.08843537, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1. ,
0. , 0. , 1. , 0. , 0. ,
0. , 1. , 0. , 0. , 1. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.01027066, 0.05755396, 0.06110172, 0.078125 , 1. ,
0. , 0. , 0.30263158, 0.01360544, 1. ,
0. , 0. , 0. , 0. , 1. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 1. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.0085504 , 0.01978417, 0.03185484, 0.03125 , 1. ,
1. , 0. , 0.51315789, 0.00340136, 0. ,
0. , 0. , 1. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 1. , 0. ,
0. , 1. , 0. , 0. , 0. ,
0. , 0. , 0. , 1. , 0. ,
0. ],
[0.02224122, 0.05215827, 0.06370968, 0.0625 , 0. ,
0. , 0. , 0.47368421, 0.04081633, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 1. ,
0. , 0. , 0. , 1. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.00896774, 0.05035971, 0.00974896, 0.015625 , 0. ,
0. , 0. , 0.5 , 0.02721088, 0. ,
0. , 0. , 0. , 0. , 0. ,
1. , 0. , 0. , 0. , 0. ,
0. , 0. , 1. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.03302084, 0.07014388, 0.00779787, 0.015625 , 1. ,
1. , 0. , 0.25 , 0.03741497, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 1. , 0. ,
0. , 0. , 1. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.00630083, 0.06115108, 0.01495838, 0. , 0. ,
0. , 0. , 0.10526316, 0.00340136, 0. ,
0. , 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 1. , 0. ,
0. , 0. , 1. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ],
[0.00951741, 0.03776978, 0.13261576, 0.140625 , 1. ,
1. , 0. , 0.47368421, 0.0170068 , 1. ,
0. , 1. , 0. , 0. , 0. ,
0. , 1. , 0. , 0. , 0. ,
0. , 0. , 1. , 0. , 0. ,
0. , 1. , 0. , 1. , 0. ,
0. , 0. , 0. , 0. , 0. ,
0. ]])
Now we can call plot_precision_recall_curve and show both results, which differ:
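# Removed in scikit-learn 1.2; newer releases offer PrecisionRecallDisplay.from_estimator
from sklearn.metrics import plot_precision_recall_curve
disp = plot_precision_recall_curve(clf_4, X_test, y1_test)
disp.ax_.set_title(f'2-class Precision-Recall curve:{average_precision_5}')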
So where does the difference come from?
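The y_score parameter of average_precision_score needs to be probability estimates (or similar continuous scores, such as the output of decision_function), not hard classification results. So your average_precision_5 is not correct. plot_precision_recall_curve obtains continuous scores from the classifier itself, which is why the value in its legend differs from the one you computed from the predicted labels.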
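To make the difference concrete, here is a minimal sketch. The first part reuses the labels and hard predictions from the question and shows that 0/1 predictions yield only two thresholds, so the average precision reduces to 0.5 * 0.625 + 0.5 * 0.5 = 0.5625. The second part uses synthetic data from make_classification as a stand-in (an assumption, since the question's X_train and y_train are not shown) and scores the test set with decision_function, which is the continuous output available for an SVC fitted without probability=True.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Labels and hard predictions copied from the question.
y1_test = np.array([1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1])
y1_clf4 = np.array([0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1])

# Hard 0/1 predictions contain only two distinct "scores", so there are only
# two thresholds and the curve has a single interior point.
precision, recall, thresholds = precision_recall_curve(y1_test, y1_clf4)
ap_hard = average_precision_score(y1_test, y1_clf4)
print(len(thresholds), ap_hard)  # two thresholds; AP matches the question's 0.5625

# Synthetic stand-in data (assumption: not the question's actual training set).
X, y = make_classification(n_samples=200, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = SVC().fit(X_train, y_train)

# Continuous scores: signed distance of each sample to the decision boundary.
y_score = clf.decision_function(X_test)
ap_from_scores = average_precision_score(y_test, y_score)

# Hard labels from predict() throw away the ranking among samples.
ap_from_labels = average_precision_score(y_test, clf.predict(X_test))
print(ap_from_scores, ap_from_labels)
```

In the question's setup, the analogous fix would be average_precision_score(y1_test, clf_4.decision_function(X_test)), which should agree with the AP shown in the plot legend.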