Getting results of H2O neural nets from: h2o.grid.grid_search.H2OGridSearch

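I've been training neural networks over a grid of hyperparameters, but I can't get any results because I keep getting the following error message.

Error message: 'int' object is not iterable

Code: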

nn = H2OGridSearch(model=H2ODeepLearningEstimator,
                   hyper_params={
                       'activation': ["Rectifier", "Tanh", "Maxout", "RectifierWithDropout", "TanhWithDropout", "MaxoutWithDropout"],
                       'hidden': [[20, 20], [50, 50], [30, 30, 30], [25, 25, 25, 25]],  ## small network, runs faster
                       'epochs': 1000000,                       ## hopefully converges earlier...
                       'rate': [0.0005, 0.001, 0.0015, 0.002, 0.0025, 0.003, 0.0035, 0.0040, 0.0045, 0.005],
                       'score_validation_samples': 10000,       ## sample the validation dataset (faster)
                       'stopping_rounds': 2,
                       'stopping_metric': "misclassification",  ## alternatives: "MSE","logloss","r2"
                       'stopping_tolerance': 0.01})
nn.train(train1_x, train1_y, train1)

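There are a few problems with how you've defined the grid. The hyper_params argument only accepts a dictionary of lists (for each hyperparameter, the list of values you want to grid over). You're seeing the Error message: 'int' object is not iterable error because you're passing plain integers instead of lists for epochs, score_validation_samples and stopping_rounds. stopping_metric and stopping_tolerance aren't lists either, so they don't belong in hyper_params as written.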
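To see why a scalar value breaks the grid, think of the grid as the Cartesian product of the value lists. The plain-Python sketch below only illustrates that idea and is not H2O's internal implementation: `itertools.product` has to iterate over every value, so the integer `epochs` raises the same `TypeError`. The helpers `non_list_keys` and `enumerate_grid` are hypothetical names.

```python
import itertools

# hyper_params as written in the question (scalars mixed in with lists)
question_params = {
    "activation": ["Rectifier", "Tanh", "Maxout", "RectifierWithDropout",
                   "TanhWithDropout", "MaxoutWithDropout"],
    "hidden": [[20, 20], [50, 50], [30, 30, 30], [25, 25, 25, 25]],
    "epochs": 1000000,
    "rate": [0.0005, 0.001, 0.0015, 0.002, 0.0025,
             0.003, 0.0035, 0.0040, 0.0045, 0.005],
    "score_validation_samples": 10000,
    "stopping_rounds": 2,
    "stopping_metric": "misclassification",
    "stopping_tolerance": 0.01,
}


def non_list_keys(hyper_params):
    """Return the keys whose values are not lists and so cannot be gridded over."""
    return [k for k, v in hyper_params.items() if not isinstance(v, list)]


def enumerate_grid(hyper_params):
    """Return every combination as a dict (Cartesian product of the value lists)."""
    keys = list(hyper_params)
    return [dict(zip(keys, combo))
            for combo in itertools.product(*hyper_params.values())]


try:
    enumerate_grid(question_params)
except TypeError as err:
    print(err)  # 'int' object is not iterable

print(non_list_keys(question_params))
# ['epochs', 'score_validation_samples', 'stopping_rounds', 'stopping_metric', 'stopping_tolerance']

# Keeping only the list-valued keys gives a structurally valid grid: 6 * 4 * 10
valid = {k: v for k, v in question_params.items() if isinstance(v, list)}
print(len(enumerate_grid(valid)))  # 240
```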

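If there are parameters you don't intend to grid over, pass them to the grid's train() method instead. I'd also recommend using a validation frame or cross-validation when doing a grid search, so that you don't have to pick the best model based on training metrics. See the example below.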
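Cross-validation adds training time, which is worth estimating before launching a grid. The sketch below is a rough count under two assumptions: the default Cartesian search builds every combination, and with nfolds set each combination trains one model per fold plus one main model on the full training frame. `count_model_builds` is a hypothetical helper, and the grid is the 6 activations by 4 architectures used in the example that follows.

```python
from math import prod


def count_model_builds(hyper_params, nfolds=0):
    """Estimate how many models a full Cartesian grid search trains.

    Assumption: every combination is built, and with nfolds >= 2 each
    combination trains nfolds cross-validation models plus one main model.
    """
    n_combinations = prod(len(values) for values in hyper_params.values())
    per_combination = nfolds + 1 if nfolds >= 2 else 1
    return n_combinations, n_combinations * per_combination


# 6 activations x 4 hidden-layer architectures
answer_grid = {
    "activation": ["Rectifier", "Tanh", "Maxout", "RectifierWithDropout",
                   "TanhWithDropout", "MaxoutWithDropout"],
    "hidden": [[20, 20], [50, 50], [30, 30, 30], [25, 25, 25, 25]],
}

print(count_model_builds(answer_grid))            # (24, 24)
print(count_model_builds(answer_grid, nfolds=5))  # (24, 144)
```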

import h2o
from h2o.estimators.deeplearning import H2ODeepLearningEstimator
from h2o.grid.grid_search import H2OGridSearch
h2o.init()

# Import a sample binary outcome training set into H2O
train = h2o.import_file("https://s3.amazonaws.com/erin-data/higgs/higgs_train_10k.csv")

# Identify predictors and response
x = train.columns
y = "response"
x.remove(y)

# For binary classification, response should be a factor
train[y] = train[y].asfactor()

# Execute a grid search (also do 5-fold CV)
grid = H2OGridSearch(model=H2ODeepLearningEstimator, hyper_params = {
            'activation' :["Rectifier","Tanh","Maxout","RectifierWithDropout","TanhWithDropout","MaxoutWithDropout"],
            'hidden':[[20,20],[50,50],[30,30,30],[25,25,25,25]]})
grid.train(x=x, y=y, training_frame=train,
           # Fixed (non-grid) settings go to train(), not to hyper_params
           score_validation_samples=10000,
           stopping_rounds=2,
           stopping_metric="misclassification",
           stopping_tolerance=0.01,
           nfolds=5)

# Look at grid results
gridperf = grid.get_grid(sort_by='mean_per_class_error')
print(gridperf)

# Lowest mean per-class error comes first
best_model = gridperf.models[0]
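The example above splits the grid values and the fixed settings by hand. Starting from a mixed dictionary like the one in the question, the split can also be done programmatically. The sketch below uses a hypothetical helper, `split_grid_params`, that treats list values as grid values and everything else as a fixed setting for `train()`. One caveat: a fixed list-valued setting such as `hidden=[50, 50]` would be mistaken for grid values, so that case still needs to be handled by hand.

```python
def split_grid_params(params):
    """Split a mixed dict into (hyper_params, fixed_params).

    Hypothetical helper: list values are treated as values to grid over,
    everything else as a fixed setting to pass to train().
    """
    hyper_params = {k: v for k, v in params.items() if isinstance(v, list)}
    fixed_params = {k: v for k, v in params.items() if not isinstance(v, list)}
    return hyper_params, fixed_params


mixed = {
    "activation": ["Rectifier", "Tanh", "Maxout"],
    "hidden": [[20, 20], [50, 50], [30, 30, 30], [25, 25, 25, 25]],
    "score_validation_samples": 10000,
    "stopping_rounds": 2,
    "stopping_metric": "misclassification",
    "stopping_tolerance": 0.01,
}

hyper_params, fixed_params = split_grid_params(mixed)
print(sorted(hyper_params))  # ['activation', 'hidden']
print(fixed_params)
```

With an H2O cluster running, the two dictionaries would then be used as `H2OGridSearch(model=H2ODeepLearningEstimator, hyper_params=hyper_params)` followed by `grid.train(x=x, y=y, training_frame=train, nfolds=5, **fixed_params)`, which mirrors the hand-written call above.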

There are more examples of how to use grid search in the H2O Python Grid Search tutorial.