How can I plot validation curves using the results from GridSearchCV?

I am training a model with GridSearchCV to find the best parameters.

Code:

grid_params = {
   'n_estimators': [100, 200, 300, 400],
   'criterion': ['gini', 'entropy'],
   'max_features': ['auto', 'sqrt', 'log2']
}

gs = GridSearchCV(
    RandomForestClassifier(),
    grid_params,
    cv=2,
    verbose=1,
    n_jobs=-1
)

clf = gs.fit(X_train, y_train)

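This is a slow process. Afterwards I print the confusion matrix, but I would also like to plot a validation curve to check for overfitting, so I use the following code: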

train_scores, valid_scores = validation_curve(clf.best_estimator_, X, y)

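The problem is that I need to set param_name and param_range, but I don't want to train again because the process is too slow.

Another option is to pass gs instead of clf.best_estimator_, but I need gs to stay fitted so that I can get other information from it.

How can I plot the validation curve and keep the fitted gs without training twice?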

In the end I used this code:

grid_params = {
   'n_estimators': [100, 200, 300, 400],
   'criterion': ['gini', 'entropy'],
   'max_features': ['auto', 'sqrt', 'log2'],
   # 'bootstrap': [True]
}

gs = GridSearchCV(
    RandomForestClassifier(),
    grid_params,
    cv=10,
    verbose=1,
    n_jobs=-1,
    return_train_score=True,
    scoring='f1_micro'
)

clf = gs.fit(X, y)

test_scores = clf.cv_results_['mean_test_score']
train_scores = clf.cv_results_['mean_train_score'] 

plt.plot(test_scores, label='test')
plt.plot(train_scores, label='train')
plt.legend(loc='best')
plt.show()

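You can use the cv_results_ attribute of GridSearchCV to get the results for every combination of hyperparameters. A validation curve is meant to show how the values of a single parameter affect the training and cross-validation scores.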
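As a minimal sketch of what cv_results_ holds, the snippet below fits a tiny grid and prints one line per parameter combination. The iris dataset, the decision tree, and the grid values are assumptions chosen only so that it runs quickly; they are not part of the question.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Toy data and a tiny grid, chosen only to keep the example fast.
X, y = load_iris(return_X_y=True)
param_grid = {"max_depth": [1, 2, 3], "criterion": ["gini", "entropy"]}

gs = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid,
    cv=3,
    return_train_score=True,  # without this, no *_train_score keys are stored
)
gs.fit(X, y)

# One entry per parameter combination: 3 depths x 2 criteria = 6 entries.
for params, train, test in zip(
    gs.cv_results_["params"],
    gs.cv_results_["mean_train_score"],
    gs.cv_results_["mean_test_score"],
):
    print(params, f"train={train:.3f}", f"test={test:.3f}")
```

Because the scores are stored during fit, reading them back requires no further training.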

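Since you are tuning several parameters with GridSearchCV, we can create several plots, one per parameter, to visualize the effect of each. The catch is that to study one particular parameter we have to average over the others. We can do this by grouping by each parameter in turn and aggregating the results.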
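The averaging step can be illustrated on a small hand-made table. The scores below are hypothetical numbers laid out like a cv_results_ DataFrame, not output from a real search.

```python
import pandas as pd

# Hypothetical cv_results_-style table: 2 values of n_estimators x 2 of max_depth.
df = pd.DataFrame({
    "param_n_estimators": [10, 10, 50, 50],
    "param_max_depth": [2, 5, 2, 5],
    "mean_train_score": [0.80, 0.90, 0.84, 0.96],
    "mean_test_score": [0.70, 0.74, 0.76, 0.80],
})

score_cols = ["mean_train_score", "mean_test_score"]

# To study n_estimators, average the scores over every value of max_depth.
by_n_estimators = df.groupby("param_n_estimators")[score_cols].mean()
print(by_n_estimators)

# To study max_depth, average over every value of n_estimators instead.
by_max_depth = df.groupby("param_max_depth")[score_cols].mean()
print(by_max_depth)
```

Each grouped table has one row per value of the parameter under study, which is the shape a validation curve needs.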

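We can simply average the mean scores, but the standard deviations have to be combined through the pooled variance, because the standard deviation of each CV run is roughly constant.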
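A small worked sketch of the pooled standard deviation, assuming every group has the same size n (here, the number of CV folds). The function name pooled_std and the input values are hypothetical and serve only to illustrate the formula.

```python
import math

def pooled_std(stds, n):
    """Pool standard deviations of groups that each contain n samples."""
    k = len(stds)
    # Pooled variance: sum((n_i - 1) * s_i^2) / sum(n_i - 1), with all n_i equal to n.
    pooled_variance = sum((n - 1) * s ** 2 for s in stds) / (k * (n - 1))
    return math.sqrt(pooled_variance)

# Hypothetical std_test_score values of three parameter combinations, 5 folds each.
stds = [0.02, 0.03, 0.06]

print(pooled_std(stds, n=5))
print(sum(stds) / len(stds))  # plain average of the stds, for comparison
```

With equal group sizes the (n - 1) factors cancel, so the pooled value is the root mean square of the individual standard deviations and is never smaller than their plain average.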

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=1000,
                           n_features=100, n_informative=2,
                           class_sep=0.5, random_state=42)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

grid_params = {
   'n_estimators': [10, 20, 50],
   # 'auto' was removed in scikit-learn 1.3, so it is left out here
   'max_features': ['sqrt', 'log2'],
   'criterion': ['gini', 'entropy'],
   'max_depth': [2, 5, 10]
}

gs = GridSearchCV(
    RandomForestClassifier(random_state=42),
    grid_params,
    cv=5,
    verbose=1,
    n_jobs=-1,
    return_train_score=True  # needed so that the train scores are stored
)

gs.fit(X_train, y_train)

import pandas as pd
df = pd.DataFrame(gs.cv_results_)
results = ['mean_test_score',
           'mean_train_score',
           'std_test_score', 
           'std_train_score']

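def pooled_var(stds):
    # Pooled standard deviation of equally sized groups, see
    # https://en.wikipedia.org/wiki/Pooled_variance#Pooled_standard_deviation
    n = 5  # size of each group (the number of CV folds, cv=5 above)
    # Parentheses around the denominator: divide by len(stds) * (n - 1)
    return np.sqrt(sum((n - 1) * (stds ** 2)) / (len(stds) * (n - 1)))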

fig, axes = plt.subplots(1, len(grid_params), 
                         figsize = (5*len(grid_params), 7),
                         sharey='row')
axes[0].set_ylabel("Score", fontsize=25)


for idx, (param_name, param_range) in enumerate(grid_params.items()):
    # Average over all other parameters. groupby sorts its keys
    # (e.g. 'entropy' before 'gini'), so reindex to follow param_range.
    grouped_df = df.groupby(f'param_{param_name}')[results]\
        .agg({'mean_train_score': 'mean',
              'mean_test_score': 'mean',
              'std_train_score': pooled_var,
              'std_test_score': pooled_var})\
        .reindex(param_range)

    axes[idx].set_xlabel(param_name, fontsize=30)
    axes[idx].set_ylim(0.0, 1.1)
    lw = 2
    axes[idx].plot(param_range, grouped_df['mean_train_score'], label="Training score",
                color="darkorange", lw=lw)
    axes[idx].fill_between(param_range, grouped_df['mean_train_score'] - grouped_df['std_train_score'],
                    grouped_df['mean_train_score'] + grouped_df['std_train_score'], alpha=0.2,
                    color="darkorange", lw=lw)
    axes[idx].plot(param_range, grouped_df['mean_test_score'], label="Cross-validation score",
                color="navy", lw=lw)
    axes[idx].fill_between(param_range, grouped_df['mean_test_score'] - grouped_df['std_test_score'],
                    grouped_df['mean_test_score'] + grouped_df['std_test_score'], alpha=0.2,
                    color="navy", lw=lw)

handles, labels = axes[0].get_legend_handles_labels()
fig.suptitle('Validation curves', fontsize=40)
fig.legend(handles, labels, loc=8, ncol=2, fontsize=20)

fig.subplots_adjust(bottom=0.25, top=0.85)  
plt.show()

Note: for parameters with string values, such as criterion, a line plot is not the right choice; you can change it to a bar chart with error bars.
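One possible way to draw such a bar chart is sketched below. The means and standard deviations are hypothetical placeholders; in practice they would come from the grouped results for criterion.

```python
import matplotlib.pyplot as plt

# Hypothetical aggregated scores for the two values of `criterion`.
labels = ["gini", "entropy"]
mean_train = [0.95, 0.94]
std_train = [0.01, 0.01]
mean_test = [0.81, 0.80]
std_test = [0.03, 0.04]

x = range(len(labels))
width = 0.35  # width of one bar

fig, ax = plt.subplots(figsize=(5, 4))
# Two bars per category, shifted left and right of the tick position.
ax.bar([i - width / 2 for i in x], mean_train, width, yerr=std_train,
       capsize=4, color="darkorange", label="Training score")
ax.bar([i + width / 2 for i in x], mean_test, width, yerr=std_test,
       capsize=4, color="navy", label="Cross-validation score")

ax.set_xticks(list(x))
ax.set_xticklabels(labels)
ax.set_xlabel("criterion")
ax.set_ylabel("Score")
ax.set_ylim(0.0, 1.1)
ax.legend(loc="lower center")
# plt.show()  # uncomment to display the figure
```

Bars avoid suggesting an ordering or a trend between categories, which a connecting line would imply.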