'Array is not a python function' Error when building simple linear model in keras

I am trying to build a simple linear model with Keras, as follows:

lin_model = Sequential([
        Lambda(x_train, input_shape=(1,28,28)),
        Flatten(),
        Dense(10, activation='softmax')
    ])

But I keep getting the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-86-664f3eb6b96f> in <module>()
      2         Lambda(x_train, input_shape=(1,28,28)),
      3         Flatten(),
----> 4         Dense(10, activation='softmax')
      5     ])
      6 lin_model.compile(Adam(), loss='categorical_crossentropy', metrics=['accuracy'])

/home/matar/anaconda2/lib/python2.7/site-packages/keras/models.pyc in __init__(self, layers, name)
    399         if layers:
    400             for layer in layers:
--> 401                 self.add(layer)
    402 
    403     def add(self, layer):

/home/matar/anaconda2/lib/python2.7/site-packages/keras/models.pyc in add(self, layer)
    434                 # and create the node connecting the current layer
    435                 # to the input layer we just created.
--> 436                 layer(x)
    437 
    438             if len(layer.inbound_nodes) != 1:

/home/matar/anaconda2/lib/python2.7/site-packages/keras/engine/topology.pyc in __call__(self, inputs, **kwargs)
    594 
    595             # Actually call the layer, collecting output(s), mask(s), and shape(s).
--> 596             output = self.call(inputs, **kwargs)
    597             output_mask = self.compute_mask(inputs, previous_mask)
    598 

/home/matar/anaconda2/lib/python2.7/site-packages/keras/layers/core.pyc in call(self, inputs, mask)
    643     def call(self, inputs, mask=None):
    644         arguments = self.arguments
--> 645         if has_arg(self.function, 'mask'):
    646             arguments['mask'] = mask
    647         return self.function(inputs, **arguments)

/home/matar/anaconda2/lib/python2.7/site-packages/keras/utils/generic_utils.pyc in has_arg(fn, name, accept_all)
    226     """
    227     if sys.version_info < (3,):
--> 228         arg_spec = inspect.getargspec(fn)
    229         if accept_all and arg_spec.keywords is not None:
    230             return True

/home/matar/anaconda2/lib/python2.7/inspect.pyc in getargspec(func)
    813         func = func.im_func
    814     if not isfunction(func):
--> 815         raise TypeError('{!r} is not a Python function'.format(func))
    816     args, varargs, varkw = getargs(func.func_code)
    817     return ArgSpec(args, varargs, varkw, func.func_defaults)

TypeError: array([[[[        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan, -0.00408252, ..., -0.00783084,
                  nan,         nan],
         ..., 
         [        nan,         nan, -0.0066643 , ..., -0.00567531,
          -0.00408252,         nan],
         [        nan,         nan,         nan, ..., -0.00408252,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan]]],


       [[[        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan, -0.00408252, ..., -0.00783084,
                  nan,         nan],
         ..., 
         [        nan,         nan, -0.0066643 , ..., -0.00567531,
          -0.00408252,         nan],
         [        nan,         nan,         nan, ..., -0.00408252,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan]]],


       [[[        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan, -0.00408252, ..., -0.00783084,
                  nan,         nan],
         ..., 
         [        nan,         nan, -0.0066643 , ..., -0.00567531,
          -0.00408252,         nan],
         [        nan,         nan,         nan, ..., -0.00408252,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan]]],


       ..., 
       [[[        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan, -0.00408252, ..., -0.00783084,
                  nan,         nan],
         ..., 
         [        nan,         nan, -0.0066643 , ..., -0.00567531,
          -0.00408252,         nan],
         [        nan,         nan,         nan, ..., -0.00408252,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan]]],


       [[[        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan, -0.00408252, ..., -0.00783084,
                  nan,         nan],
         ..., 
         [        nan,         nan, -0.0066643 , ..., -0.00567531,
          -0.00408252,         nan],
         [        nan,         nan,         nan, ..., -0.00408252,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan]]],


       [[[        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan],
         [        nan,         nan, -0.00408252, ..., -0.00783084,
                  nan,         nan],
         ..., 
         [        nan,         nan, -0.0066643 , ..., -0.00567531,
          -0.00408252,         nan],
         [        nan,         nan,         nan, ..., -0.00408252,
                  nan,         nan],
         [        nan,         nan,         nan, ...,         nan,
                  nan,         nan]]]]) is not a Python function

How can I fix this?

As you already pointed out in the title of your question:

'Array is not a python function' Error when building simple linear model in keras

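x_train is an array, but the first argument of the Keras Lambda layer must be a function (the operation to apply to the layer's input), not your data.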

You can read more about it here: https://keras.io/layers/core/
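The check that fails in the traceback is easy to reproduce without Keras. Below is a rough stdlib sketch of what the Python 2 branch of keras.utils.generic_utils.has_arg does with the first argument of Lambda: inspect.getargspec only accepts plain Python functions, so an array like x_train is rejected before any computation happens. check_lambda_function and scale_pixels are hypothetical names used only for illustration, and the sketch uses inspect.signature for the final parameter lookup instead of getargspec.

```python
import inspect


def check_lambda_function(fn):
    """Hypothetical helper mimicking the check in the traceback.

    On Python 2, has_arg calls inspect.getargspec(fn), which unwraps methods
    and then raises TypeError unless fn is a plain Python function.
    """
    if inspect.ismethod(fn):
        fn = fn.__func__
    if not inspect.isfunction(fn):
        raise TypeError('{!r} is not a Python function'.format(fn))
    # Keras uses this lookup to decide whether to pass `mask` to the function
    return 'mask' in inspect.signature(fn).parameters


def scale_pixels(x):
    # Hypothetical preprocessing function: the kind of thing Lambda expects
    return x / 255.0


print(check_lambda_function(scale_pixels))               # False: no `mask` parameter
print(check_lambda_function(lambda x, mask=None: x))     # True
try:
    check_lambda_function([[0.0, 1.0]])                  # data, not a function
except TypeError as err:
    print(err)                                           # [[0.0, 1.0]] is not a Python function
```

Any callable that takes the layer input works here, including a lambda expression; what does not work is the data itself.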

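Basically, you are passing your input data to a layer while you are still building the model. That is not how it works: the layers only describe the transformations, and the data is passed in later, when the model is fit or evaluated.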
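To make the split between model definition and data concrete, here is a rough NumPy sketch of what the question's Lambda, Flatten, Dense(10, activation='softmax') stack computes. It is not how Keras is implemented; LinearSoftmaxModel and scale_pixels are hypothetical names, and dividing by 255 is just an assumed preprocessing step. The constructor only receives a function and shapes, and the data appears for the first time in predict().

```python
import numpy as np


class LinearSoftmaxModel:
    """Hypothetical NumPy stand-in for
    Sequential([Lambda(fn), Flatten(), Dense(10, activation='softmax')])."""

    def __init__(self, preprocess, input_shape=(1, 28, 28), units=10, seed=0):
        # Like Lambda, the constructor wants a callable, not an array
        if not callable(preprocess):
            raise TypeError('{} object is not callable'.format(type(preprocess).__name__))
        self.preprocess = preprocess
        n_in = int(np.prod(input_shape))
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=0.01, size=(n_in, units))  # Dense kernel
        self.b = np.zeros(units)                              # Dense bias

    def predict(self, x):
        x = self.preprocess(x)                       # Lambda: apply the stored function
        x = x.reshape(len(x), -1)                    # Flatten: (n, 1, 28, 28) -> (n, 784)
        logits = x @ self.W + self.b                 # Dense: affine map
        logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        e = np.exp(logits)
        return e / e.sum(axis=1, keepdims=True)      # softmax over the 10 classes


def scale_pixels(x):
    # Hypothetical preprocessing: scale raw pixel values into [0, 1]
    return x / 255.0


model = LinearSoftmaxModel(scale_pixels)             # no data involved yet
x_batch = np.random.default_rng(1).integers(0, 256, size=(4, 1, 28, 28)).astype(np.float32)
probs = model.predict(x_batch)                       # data is passed only here
print(probs.shape)                                   # (4, 10)
```

In Keras itself the corresponding change should be to pass a function, e.g. Lambda(scale_pixels, input_shape=(1,28,28)), and to hand x_train to fit() rather than to the layer.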

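from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasRegressor
from sklearn.model_selection import cross_val_score

# first define a function that creates and compiles the model (no data involved)
def baseline_model():
    model = Sequential()
    model.add(Dense(13, input_dim=13, kernel_initializer='normal', activation='relu'))
    model.add(Dense(1, kernel_initializer='normal'))
    # Compile model
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model

# evaluate the model with cross-validation: the data (x_train, y_input)
# is passed here, to the estimator, not to the layers
estimator = KerasRegressor(build_fn=baseline_model, epochs=100, batch_size=5, verbose=0)
results = cross_val_score(estimator, x_train, y_input)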
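For completeness, here is a stdlib-only sketch of roughly what that last line does. As I understand the Keras 2 scikit-learn wrapper, cross_val_score splits the data into k unshuffled folds (k defaults to 3 or 5 depending on the scikit-learn version), KerasRegressor.fit calls build_fn to get a fresh model for each fold and trains it on the training folds only, and KerasRegressor.score returns the negative loss. kfold_indices, cross_val_neg_mse and MeanRegressor are hypothetical names; the mean predictor stands in for baseline_model() so the example runs without Keras.

```python
def kfold_indices(n_samples, k):
    """Yield (train, test) index lists for an unshuffled k-fold split.
    The first n_samples % k folds get one extra sample, as in scikit-learn's KFold."""
    sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in sizes:
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size


class MeanRegressor:
    """Hypothetical stand-in for baseline_model(): always predicts the training mean."""

    def fit(self, X, y):
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


def cross_val_neg_mse(build_fn, X, y, k=3):
    """Build a fresh model per fold, fit it on the training folds only,
    and score it on the held-out fold as negative mean squared error."""
    scores = []
    for train, test in kfold_indices(len(X), k):
        model = build_fn()  # new, untrained model for every fold
        model.fit([X[i] for i in train], [y[i] for i in train])
        preds = model.predict([X[i] for i in test])
        mse = sum((p - y[i]) ** 2 for p, i in zip(preds, test)) / len(test)
        scores.append(-mse)  # higher is better, like KerasRegressor.score
    return scores


X = [[i] for i in range(6)]
y = [0, 0, 0, 6, 6, 6]
scores = cross_val_neg_mse(MeanRegressor, X, y)
print(scores)  # [-20.25, -9.0, -20.25]
```

Passing the class itself as build_fn mirrors build_fn=baseline_model: the estimator receives a recipe for building a model, and the data only reaches that model inside fit.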