Optimize classifier for multiclass Brier score instead of accuracy
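I am more interested in optimizing my multiclass problem for the Brier score instead of accuracy. To do that, I am evaluating my classifier on the output of predict_proba(), for example: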
import numpy as np

# True classes, one-hot encoded (one row per sample)
targets = np.array(
    [[1, 0, 0],
     [0, 1, 0],
     [1, 0, 0],
     [0, 1, 0],
     [0, 0, 1],
     [1, 0, 0],
     [0, 0, 1],
     [0, 0, 1]]
)
# Predicted class probabilities, e.g. from predict_proba()
probs = np.array(
    [[0.9, 0.05, 0.05],
     [0.1, 0.8, 0.1],
     [0.7, 0.2, 0.1],
     [0.1, 0.9, 0],
     [0, 0, 1],
     [0.5, 0.3, 0.2],
     [0.1, 0.5, 0.4],
     [0.34, 0.33, 0.33]]
)

def brier_multi(targets, probs):
    # Mean over samples of the squared error summed across classes
    return np.mean(np.sum((probs - targets) ** 2, axis=1))

brier_multi(targets, probs)
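To see what this number is made of, the per-sample terms for the same eight samples can be listed explicitly. This is a small NumPy check; y_onehot and y_prob are just local names for the one-hot true classes and the predicted probabilities. Each row contributes between 0 (all probability on the true class) and 2 (all probability on a single wrong class), so lower is better.

```python
import numpy as np

# Same eight samples as above: one-hot true classes and predicted probabilities
y_onehot = np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0],
                     [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1]])
y_prob = np.array([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.7, 0.2, 0.1],
                   [0.1, 0.9, 0.0], [0.0, 0.0, 1.0], [0.5, 0.3, 0.2],
                   [0.1, 0.5, 0.4], [0.34, 0.33, 0.33]])

# Squared error summed over the classes, one value per sample
per_sample = np.sum((y_prob - y_onehot) ** 2, axis=1)
print(np.round(per_sample, 4))

# Multiclass Brier score = mean of the per-sample values
score = per_sample.mean()
print(score)  # about 0.23855 (the per-sample terms sum to 1.9084, divided by 8)
```

The last two samples (0.62 and 0.6734) dominate the score: they put most of their probability on the wrong classes, which accuracy alone would only count as two misclassifications.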
Is it possible to have scikit-learn classifiers optimize directly for the multiclass Brier score during training, instead of for accuracy?
Edit:
...
pipe = Pipeline(
    steps=[
        ("preprocessor", preprocessor),
        ("selector", None),
        ("classifier", model.get("classifier")),
    ]
)

def brier_multi(targets, probs):
    ohe_targets = OneHotEncoder().fit_transform(targets.reshape(-1, 1))
    return np.mean(np.sum(np.square(probs - ohe_targets), axis=1))

brier_multi_loss = make_scorer(
    brier_multi,
    greater_is_better=False,
    needs_proba=True,
)

search = GridSearchCV(
    estimator=pipe,
    param_grid=model.get("param_grid"),
    scoring=brier_multi_loss,
    cv=3,
    n_jobs=-1,
    refit=True,
    verbose=3,
)
search.fit(X_train, y_train)
...
The scores come out as nan:
/home/andreas/.local/lib/python3.8/site-packages/sklearn/model_selection/_search.py:969: UserWarning: One or more of the test scores are non-finite: [nan nan nan nan nan nan nan nan nan]
warnings.warn(
You already know about the scoring parameter, so you just need to wrap brier_multi in the format GridSearchCV expects. There is a utility for that, make_scorer:
from sklearn.metrics import make_scorer

neg_mc_brier_score = make_scorer(
    brier_multi,
    greater_is_better=False,
    needs_proba=True,
)
GridSearchCV(..., scoring=neg_mc_brier_score)
See the User Guide and the docs for make_scorer.
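Unfortunately, this won't run as is, because your version of the scorer expects a one-hot-encoded target array, whereas for multiclass problems sklearn passes y_true as a 1D array of labels. As a hack to check that everything else works, you could change it to: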
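import numpy as np
from sklearn.preprocessing import OneHotEncoder
def brier_multi(targets, probs):
    # One-hot encode the 1D label array that sklearn passes as y_true
    ohe_targets = OneHotEncoder().fit_transform(targets.reshape(-1, 1))
    return np.mean(np.sum(np.square(probs - ohe_targets), axis=1))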
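A variant of the same hack that avoids OneHotEncoder is sketched below. It keeps the same assumption (labels are the integers 0, 1, ..., n_classes-1), and brier_multi_labels is a hypothetical name. Indexing an identity matrix with the label array builds the one-hot rows; because the matrix size is taken from the number of probability columns, the encoding always has as many columns as probs, even in a CV fold that happens not to contain every class (an OneHotEncoder fitted on that fold's labels alone would produce fewer columns).

```python
import numpy as np


def brier_multi_labels(y_true, probs):
    """Multiclass Brier score from 1D integer labels and a predict_proba matrix.

    Hypothetical helper; assumes y_true holds integers 0..n_classes-1 that
    line up with the columns of probs.
    """
    probs = np.asarray(probs, dtype=float)
    y_true = np.asarray(y_true)  # also accepts lists and pandas Series
    # Row i of np.eye(k) is the one-hot vector for class i
    onehot = np.eye(probs.shape[1])[y_true]
    return np.mean(np.sum((probs - onehot) ** 2, axis=1))


labels = np.array([0, 1, 0, 1, 2, 0, 2, 2])
probs = np.array([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.7, 0.2, 0.1],
                  [0.1, 0.9, 0.0], [0.0, 0.0, 1.0], [0.5, 0.3, 0.2],
                  [0.1, 0.5, 0.4], [0.34, 0.33, 0.33]])
print(brier_multi_labels(labels, probs))  # same value as with one-hot targets
```

It can be wrapped with make_scorer just like brier_multi (greater_is_better=False, needs_proba=True). In scikit-learn 1.4 and later, needs_proba is deprecated and the equivalent spelling is response_method="predict_proba".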
But I'd encourage you to make it more robust (what if the classes are not just 0, 1, ..., n_classes-1?).
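For what it's worth, there is a PR in progress at sklearn to add a multiclass Brier score: https://github.com/scikit-learn/scikit-learn/pull/22046 (be sure to also look at the linked PR18699, since that is where development and review started).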