如何使用 LSTM 生成序列?

How to generate sequence using LSTM?

我想在激活特定输入时生成一个序列。我想根据其对应的输入神经元激活生成奇数或偶数序列。我正在尝试使用 LSTM 创建模型,因为它可以记住短期订单。

我这样试过

import numpy as np
from keras.models import Sequential
from keras.layers import Dense,LSTM


X=np.array([[1,0],
            [0,1]])

Y=np.array([[1,3,5,7,9],
            [2,4,6,8,10]])

model = Sequential()
model.add(Dense(10, input_shape=(2))
model.add(LSTM(5, return_sequences=True))
model.add(LSTM(5, return_sequences=False))
model.add(Dense(5))
model.compile(loss='mse', optimizer='adam')

model.fit(X,Y)

但是当我试图拟合模型时它给我这个错误

NameError: name 'model' is not defined

model.add(Dense(10, input_shape=(2)) 更改为

model.add(Dense(10, input_shape=(2,)))

model.add(Dense(5)) # Remove this 

注意下面两个代码是等价的:

model = Sequential()
model.add(Dense(32, input_shape=(2,)))

model = Sequential()
model.add(Dense(32, input_dim=2))

要在 Keras 中使用 RNN,您需要在数据中引入一个额外的维度:时间步长。在您的情况下,您希望有 5 个时间步长。因为您希望输入和输出数据之间存在一对多关系,所以您需要将输入数据复制 5 次。最后一个 LSTM 层也必须设置为 return 序列,因为您需要每个时间步的结果,而不仅仅是最后一个。为了让 Dense 层知道时间维度,你需要用 TimeDistributed 层包裹它们。最后一个 Dense 层只有一个输出,因为它每个时间步只输出一个结果。

import numpy as np
from keras.models import Sequential
from keras.layers import Dense,LSTM
from keras.layers.wrappers import TimeDistributed

X=np.array([[[1, 0],
        [1, 0],
        [1, 0],
        [1, 0],
        [1, 0]],

       [[0, 1],
        [0, 1],
        [0, 1],
        [0, 1],
        [0, 1]]])


Y=np.array([[[ 1],
        [ 3],
        [ 5],
        [ 7],
        [ 9]],

       [[ 2],
        [ 4],
        [ 6],
        [ 8],
        [10]]])


model = Sequential()
model.add(TimeDistributed(Dense(10), input_shape=(5, 2)))
model.add(LSTM(5, return_sequences=True))
model.add(LSTM(5, return_sequences=True))
model.add(TimeDistributed(Dense(1)))
model.compile(loss='mse', optimizer='adam')

model.fit(X,Y, nb_epoch=4000)

model.predict(X)

有了这个,我在大约 4000 个纪元后得到以下结果:

Epoch 4000/4000
2/2 [==============================] - 0s - loss: 0.0032
Out[20]:
array([[[ 1.02318883],
        [ 2.96530271],
        [ 5.03490496],
        [ 6.99484348],
        [ 9.00506973]],

       [[ 2.05096436],
        [ 3.96090508],
        [ 5.98824072],
        [ 8.0701828 ],
        [ 9.85805798]]], dtype=float32)