从 sklearn 使用 GridSearchCV 获取 ValueError
Getting ValueError using GridSearchCV from sklearn
我想使用 sklearn 的 GridSearchCV 优化学习率,然后再优化模型的其他超参数。你可以在下面看到我的代码。不幸的是,我总是收到错误:ValueError: learning_rate is not a legal parameter
这里也有类似的问题(例如 "ValueError: activation is not a legal parameter" with Keras classifier 和 learning_rate is not a legal parameter),但对我没有帮助。我也尝试把 learning_rate 改名为 lr 或 learn_rate,但都没有用。
# Sequential API
# Asker's original (buggy) build function: a 3-layer MLP regressor.
# NOTE(review): the build parameter is named `learn_rate` here, while the
# grid below searches over the key 'learning_rate' — this name mismatch is
# what triggers "ValueError: learning_rate is not a legal parameter".
def create_model(learn_rate=0.01):
model = Sequential()
model.add(Dense(128, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(1))
# NOTE(review): `keras` is referenced here but only tensorflow.keras
# submodules appear in the imports shown — presumably `import keras`
# exists in the asker's full script; verify.
opt = keras.optimizers.Adam(lr=learn_rate)
model.compile(optimizer=opt,
loss='mean_squared_error',
metrics=['mae', 'mean_absolute_percentage_error'])
return model
# Hyperparameter Tuning
# BUG: build_fn must receive the function itself, not the result of calling
# it — create_model() returns an already-compiled Model, so KerasRegressor
# has no factory to forward grid parameters into.
model = KerasRegressor(build_fn=create_model(), verbose=0)
# BUG: the key 'learning_rate' does not match create_model's parameter name
# `learn_rate`, causing the reported ValueError.
param_grid = {'learning_rate': [0.001, 0.01, 0.1]}
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=5)
# NOTE(review): X_train / Y_train are not defined in the snippet shown.
grid.fit(X_train, Y_train)
print(grid.best_params_)
希望有人能帮助我解决我的问题。
如果您在 create_model 函数和参数网格 param_grid 中都使用 'learning_rate' 作为参数名,并且在 KerasRegressor 中把 create_model() 替换为 create_model(传入函数本身而不是它的调用结果),问题就可以解决。
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_regression
# Sequential API
def create_model(learning_rate=0.01):
    """Build and compile a small MLP for regression.

    The parameter is deliberately named `learning_rate` so it matches the
    key used in param_grid — that is what lets GridSearchCV forward the
    value through KerasRegressor into this factory.
    """
    optimizer = Adam(learning_rate=learning_rate)
    net = Sequential()
    # Two hidden ReLU layers, then a single linear output unit.
    for units in (128, 32):
        net.add(Dense(units, activation='relu'))
    net.add(Dense(1))
    net.compile(
        optimizer=optimizer,
        loss='mean_squared_error',
        metrics=['mae', 'mean_absolute_percentage_error'],
    )
    return net
# Sample data: 1000 rows, 10 features, fixed seed for reproducibility.
X, y = make_regression(n_samples=1000, n_features=10, random_state=100)

# Hyperparameter tuning: pass the factory *uncalled* so GridSearchCV can
# inject `learning_rate` from the grid into create_model on each fit.
regressor = KerasRegressor(build_fn=create_model, verbose=0)
search_space = {'learning_rate': [0.001, 0.01, 0.1]}
grid = GridSearchCV(estimator=regressor, param_grid=search_space, cv=5)
grid.fit(X, y)
print(grid.best_params_)
# {'learning_rate': 0.1}
我想使用 sklearn 的 GridSearchCV 优化学习率,然后再优化模型的其他超参数。你可以在下面看到我的代码。不幸的是,我总是收到错误:ValueError: learning_rate is not a legal parameter
这里也有类似的问题(例如 "ValueError: activation is not a legal parameter" with Keras classifier 和 learning_rate is not a legal parameter),但对我没有帮助。我也尝试把 learning_rate 改名为 lr 或 learn_rate,但都没有用。
# Sequential API
# Asker's original (buggy) build function: a 3-layer MLP regressor.
# NOTE(review): the build parameter is named `learn_rate` here, while the
# grid below searches over the key 'learning_rate' — this name mismatch is
# what triggers "ValueError: learning_rate is not a legal parameter".
def create_model(learn_rate=0.01):
model = Sequential()
model.add(Dense(128, activation='relu'))
model.add(Dense(32, activation='relu'))
model.add(Dense(1))
# NOTE(review): `keras` is referenced here but only tensorflow.keras
# submodules appear in the imports shown — presumably `import keras`
# exists in the asker's full script; verify.
opt = keras.optimizers.Adam(lr=learn_rate)
model.compile(optimizer=opt,
loss='mean_squared_error',
metrics=['mae', 'mean_absolute_percentage_error'])
return model
# Hyperparameter Tuning
# BUG: build_fn must receive the function itself, not the result of calling
# it — create_model() returns an already-compiled Model, so KerasRegressor
# has no factory to forward grid parameters into.
model = KerasRegressor(build_fn=create_model(), verbose=0)
# BUG: the key 'learning_rate' does not match create_model's parameter name
# `learn_rate`, causing the reported ValueError.
param_grid = {'learning_rate': [0.001, 0.01, 0.1]}
grid = GridSearchCV(estimator=model, param_grid=param_grid, cv=5)
# NOTE(review): X_train / Y_train are not defined in the snippet shown.
grid.fit(X_train, Y_train)
print(grid.best_params_)
希望有人能帮助我解决我的问题。
如果您在 create_model 函数和参数网格 param_grid 中都使用 'learning_rate' 作为参数名,并且在 KerasRegressor 中把 create_model() 替换为 create_model(传入函数本身而不是它的调用结果),问题就可以解决。
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_regression
# Sequential API
def create_model(learning_rate=0.01):
    """Build and compile a small MLP for regression.

    The parameter is deliberately named `learning_rate` so it matches the
    key used in param_grid — that is what lets GridSearchCV forward the
    value through KerasRegressor into this factory.
    """
    optimizer = Adam(learning_rate=learning_rate)
    net = Sequential()
    # Two hidden ReLU layers, then a single linear output unit.
    for units in (128, 32):
        net.add(Dense(units, activation='relu'))
    net.add(Dense(1))
    net.compile(
        optimizer=optimizer,
        loss='mean_squared_error',
        metrics=['mae', 'mean_absolute_percentage_error'],
    )
    return net
# Sample data: 1000 rows, 10 features, fixed seed for reproducibility.
X, y = make_regression(n_samples=1000, n_features=10, random_state=100)

# Hyperparameter tuning: pass the factory *uncalled* so GridSearchCV can
# inject `learning_rate` from the grid into create_model on each fit.
regressor = KerasRegressor(build_fn=create_model, verbose=0)
search_space = {'learning_rate': [0.001, 0.01, 0.1]}
grid = GridSearchCV(estimator=regressor, param_grid=search_space, cv=5)
grid.fit(X, y)
print(grid.best_params_)
# {'learning_rate': 0.1}