Input dimension for LSTM

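I have the following model.

It predicts the next item from the last 5 items.

(That is, given item[0] to item[4], it predicts item[5].)

My training data has shape (94, 11026).

So I pass an array of shape (93, 5, 11026) as the input and an array of shape (93, 11026) as the target. (I then split these into training and validation sets.)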
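A minimal NumPy sketch of building such windows, assuming the rows of the (94, 11026) array are consecutive items in time. `make_windows` and the 80/20 split are hypothetical illustrations, not the asker's code. With this particular construction, 94 rows yield 94 − 5 = 89 windows.

```python
import numpy as np

def make_windows(data, window=5):
    """Turn a (time, features) array into inputs of shape (samples, window, features)
    and next-item targets of shape (samples, features)."""
    # Input i holds rows i .. i+window-1; its target is row i+window.
    x = np.stack([data[i:i + window] for i in range(len(data) - window)])
    y = data[window:]
    return x, y

# Random stand-in for the real (94, 11026) training data.
data = np.random.rand(94, 11026).astype("float32")
x, y = make_windows(data)
print(x.shape, y.shape)  # (89, 5, 11026) (89, 11026)

# Hypothetical 80/20 split into training and validation sets (no shuffling, keeps time order).
split = int(len(x) * 0.8)
x_train, val_x = x[:split], x[split:]
y_train, val_y = y[:split], y[split:]
print(x_train.shape, val_x.shape)  # (71, 5, 11026) (18, 5, 11026)
```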

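    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import LSTM, Dense
    from tensorflow.keras.optimizers import Adam

    n_in = 11026   # number of features per item
    n_hidden = 512
    model = Sequential()
    model.add(LSTM(n_hidden, activation=None, input_shape=(5, n_in), return_sequences=True))
    model.add(Dense(n_hidden, activation="linear"))
    model.add(Dense(n_in, activation="linear"))
    opt = Adam(learning_rate=0.001)
    model.compile(loss='mse', optimizer=opt)
    model.summary()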

Then I train it with:

    history = model.fit(x, y, epochs=epoch, batch_size=10, validation_data=(val_x, val_y))

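However, this error occurs.

I think it is telling me to pass shape [10, 11026] instead of [10, 5, 11026].

But what I want to do is give the model the last 5 items and predict the next one.

Why is this wrong? (I think it works with SimpleRNN.)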

    Model: "sequential"
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    lstm (LSTM)                  (None, 5, 512)            23631872
    _________________________________________________________________
    dense (Dense)                (None, 5, 512)            262656
    _________________________________________________________________
    dense_1 (Dense)              (None, 5, 11026)          5656338
    =================================================================
    Total params: 29,550,866
    Trainable params: 29,550,866
    Non-trainable params: 0
    _________________________________________________________________
    Traceback (most recent call last):
      File "manage.py", line 22, in <module>
        main()
      File "manage.py", line 18, in main
        execute_from_command_line(sys.argv)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/django/core/management/__init__.py", line 401, in execute_from_command_line
        utility.execute()
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/django/core/management/__init__.py", line 395, in execute
        self.fetch_command(subcommand).run_from_argv(self.argv)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/django/core/management/base.py", line 330, in run_from_argv
        self.execute(*args, **cmd_options)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/django/core/management/base.py", line 371, in execute
        output = self.handle(*args, **options)
      File "/Users/whitebear/CodingWorks/httproot/aiwave/defapp/management/commands/learn.py", line 205, in handle
        history = model.fit(f.x, f.y, epochs=epoch, batch_size=10,validation_data=(f.val_x, f.val_y))
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
        return method(self, *args, **kwargs)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
        tmp_logs = train_function(iterator)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
        result = self._call(*args, **kwds)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py", line 840, in _call
        return self._stateless_fn(*args, **kwds)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 2829, in __call__
        return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
        cancellation_manager=cancellation_manager)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
        ctx, args, cancellation_manager=cancellation_manager))
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 550, in call
        ctx=ctx)
      File "/Users/whitebear/anaconda3/envs/aiwave/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
        inputs, attrs, num_outputs)
    tensorflow.python.framework.errors_impl.InvalidArgumentError:  Incompatible shapes: [10,5,11026] vs. [10,11026]
         [[node gradient_tape/mean_squared_error/BroadcastGradientArgs (defined at /Users/whitebear/CodingWorks/httproot/aiwave/defapp/management/commands/learn.py:205) ]] [Op:__inference_train_function_2084]

    Function call stack:
    train_function

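Your model predicts 5 items instead of 1. With `return_sequences=True` the LSTM emits an output at every one of the 5 time steps, and the following `Dense` layers are applied to each step, so the model's output has shape (None, 5, 11026), as the summary shows. Your targets have shape (None, 11026), so the mean-squared-error loss cannot compare them: that is the `Incompatible shapes: [10,5,11026] vs. [10,11026]` error. A SimpleRNN with `return_sequences=True` would fail in the same way.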
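The mismatch can be reproduced without Keras. This small NumPy sketch uses zero arrays as stand-ins for one batch of predictions and targets (batch size 10, as in the `model.fit` call). TensorFlow follows NumPy-style broadcasting rules, so the same pair of shapes is rejected there.

```python
import numpy as np

batch, steps, n_in = 10, 5, 11026
y_pred = np.zeros((batch, steps, n_in), dtype="float32")  # model output with return_sequences=True
y_true = np.zeros((batch, n_in), dtype="float32")         # one target item per sample

# Squared error as in MSE: broadcasting aligns trailing axes, so 5 is compared with 10 and fails.
try:
    np.mean((y_pred - y_true) ** 2)
    mismatch = False
except ValueError as err:
    mismatch = True
    print(err)  # operands could not be broadcast together ...

# Keeping only the last time step gives a shape that matches the targets.
last_step = y_pred[:, -1, :]
print(last_step.shape)  # (10, 11026)
```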

Try removing `return_sequences=True`, so that the LSTM returns only the output of its last time step:

    model.add(LSTM(n_hidden, activation=None, input_shape=(5,11026)))
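For reference, here is a sketch of the whole model with that change, checked on random stand-in data. It assumes TensorFlow 2.x with `tf.keras`, and it declares the input with `Input` instead of `input_shape` (equivalent here). The output now has shape (batch, 11026), which matches the targets, so `fit` no longer raises the shape error.

```python
import numpy as np
from tensorflow.keras import Input
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

n_in = 11026   # features per item
n_hidden = 512

model = Sequential()
model.add(Input(shape=(5, n_in)))            # 5 past items per sample
model.add(LSTM(n_hidden, activation=None))   # no return_sequences: only the last step is returned
model.add(Dense(n_hidden, activation="linear"))
model.add(Dense(n_in, activation="linear"))  # one predicted item per sample
model.compile(loss="mse", optimizer=Adam(learning_rate=0.001))

# Random stand-in data: 20 windows of 5 items, and the item following each window.
x = np.random.rand(20, 5, n_in).astype("float32")
y = np.random.rand(20, n_in).astype("float32")
history = model.fit(x, y, epochs=1, batch_size=10, verbose=0)

pred = model.predict(x[:2], verbose=0)
print(pred.shape)  # (2, 11026)
```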