如何在 sklearn 中使用分层交叉验证处理多类

How to handle multiclass with Stratified Cross Validation in sklearn

from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.model_selection import StratifiedKFold
from xgboost import XGBClassifier
import time

params = {
    'min_child_weight': [1, 5, 10],
    'gamma': [0.5, 1, 1.5, 2, 5],
    'subsample': [0.6, 0.8, 1.0],
    'colsample_bytree': [0.6, 0.8, 1.0],
    'max_depth': [3, 4, 5]
    }



xgb = XGBClassifier(learning_rate=0.02, n_estimators=600,
                silent=True, nthread=1)

folds = 5
param_comb = 5

skf = StratifiedKFold(n_splits=folds, shuffle = True, random_state = 1001)

random_search = RandomizedSearchCV(xgb, param_distributions=params, n_iter=param_comb, scoring=['f1_macro','precision_macro'], n_jobs=4, cv=skf.split(X_train,y_train), verbose=3, random_state=1001)

start_time = time.clock() # timing starts from this point for "start_time" variable
random_search.fit(X_train, y_train)
elapsed = (time.clock() - start) # timing ends here for "start_time" 
variable

我的代码在上面,我的 y_train 是一个 pandas 多类系列,整数从 0 到 9。

y_train.head()
1041    8
1177    7
2966    0
1690    2
2115    1
Name: Industry, dtype: object

一旦我运行上面的设置代码,我收到错误消息:

ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead.

我搜索了其他类似的问题,我尝试使用 sklearn.model_selection 中的 cross_validate 并尝试使用其他与 multiclass 兼容的指标,但仍然得到相同的错误消息。

我是否可以通过基于性能指标的分层交叉验证对参数进行网格搜索?

更新:修复dtype问题后,我想将多个指标传递给scoring=,我尝试过这种方式,因为我阅读了这篇文档(http://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter):

random_search = RandomizedSearchCV(xgb, param_distributions=params, n_iter=param_comb, scoring=['f1_macro','precision_macro'], n_jobs=4, cv=skf.split(X_train,y_train), verbose=3, random_state=1001)

然后我失败了,警告如下:

ValueError                                Traceback (most recent call 
last)
<ipython-input-67-dd57cd97c89c> in <module>()
 36 # Here we go
 37 start_time = time.clock() # timing starts from this point for 
"start_time" variable
---> 38 random_search.fit(X_train, y_train)
 39 elapsed = (time.clock() - start) # timing ends here for "start_time" variable

/anaconda3/lib/python3.6/site- 
packages/sklearn/model_selection/_search.py in fit(self, X, y, groups, 
**fit_params)
609                                  "available for that metric. If 
this is not "
610                                  "needed, refit should be set to 
False "
--> 611                                  "explicitly. %r was passed." % 
self.refit)
612             else:
613                 refit_metric = self.refit

ValueError: For multi-metric scoring, the parameter refit must be set 
to a scorer key to refit an estimator with the best parameter setting 
on the whole data and make the best_* attributes available for that 
metric. If this is not needed, refit should be set to False explicitly. 
True was passed.

如何解决这个问题?

写成here in user guide:

When specifying multiple metrics, the refit parameter must be set to the metric (string) for which the best_params_ will be found and used to build the best_estimator_ on the whole dataset. If the search should not be refit, set refit=False. Leaving refit to the default value None will result in an error when using multiple metrics.

由于您在此处使用多个指标:

random_search = RandomizedSearchCV(xgb, param_distributions=params,
                                   n_iter=param_comb, 
                                   scoring=['f1_macro','precision_macro'], 
                                   n_jobs=4, 
                                   cv=skf.split(X_train,y_train), 
                                   verbose=3, random_state=1001)

RandomizedSearchCV 不知道如何找到最佳参数。它不能从两种不同的评分策略中选择最佳分数。因此,您需要指定希望它用于查找最佳参数的评分类型。

为此,您需要将 refit 参数设置为您在 scoring 中使用的选项之一。像这样:

random_search = RandomizedSearchCV(xgb, param_distributions=params,
                                   ...
                                   scoring=['f1_macro','precision_macro'], 
                                   ...
                                   refit = 'f1_macro')