LSTM Dimensions Are Incompatible

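I am using an LSTM architecture for a multi-class classification problem, and I am getting an incompatible shape error. Please help me debug the model. Thanks in advance.

Here is the model: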

# Build the network
from tensorflow.keras import layers, models

model = models.Sequential()
model.add(layers.LSTM(1024, activation='tanh', input_shape=x_train.shape[1:], return_sequences=True))
model.add(layers.LSTM(512, activation='tanh', return_sequences=True))
model.add(layers.Flatten())
model.add(layers.Dense(3, activation='sigmoid'))
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

Here is the summary of this LSTM model:

finished loading 7740 subjects from 3 classes
train / test split: 6192, 1548
training data shape:  (6192, 16000, 1)
training labels shape:  (6192, 3)
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm (LSTM)                  (None, 16000, 1024)       4202496   
_________________________________________________________________
lstm_1 (LSTM)                (None, 16000, 512)        3147776   
_________________________________________________________________
flatten (Flatten)            (None, 8192000)           0         
_________________________________________________________________
dense (Dense)                (None, 3)                 24576003  
=================================================================
Total params: 31,926,275
Trainable params: 31,926,275
Non-trainable params: 0

Here is the training/fitting call:

results = model.fit(x_train, y_train, epochs=500, batch_size=16, validation_data=(x_test, y_test))

The error I get:

ValueError: Input 0 of layer lstm is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: (None, 800, 20, 1)

At first glance, I think your problem is in this line:

results = model.fit(x_train, x_train, epochs=500, batch_size=16, validation_data=(x_test, x_test))

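Normally we fit the training data against its labels. In the line above you are trying to fit the data to itself, but the model architecture you built is not designed for that: its final Dense(3) layer outputs three class scores per sample, not a reconstruction of the input.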

So try changing it as follows:

results = model.fit(x_train, y_train, epochs=500, batch_size=16, validation_data=(x_test, y_test))

This may just be a typo on your part. Change those lines and see if it works.
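If the error persists after that change, it may be worth checking the array shapes directly. The error message reports an input of shape (None, 800, 20, 1), which is 4-D, while an LSTM layer expects 3-D input of the form (samples, timesteps, features). The sketch below is only an illustration using NumPy: the helper name `to_lstm_input` is hypothetical, and it assumes that the 800 x 20 values of each sample are meant to be read row by row as one sequence of 16000 steps, which may or may not match how your data was framed.

```python
import numpy as np


def to_lstm_input(x, timesteps=16000, features=1):
    """Return x as a 3-D array (samples, timesteps, features).

    Hypothetical helper: assumes each sample holds exactly
    timesteps * features values stored in row-major order.
    """
    x = np.asarray(x)
    if x.ndim == 3 and x.shape[1:] == (timesteps, features):
        return x  # already in the layout the LSTM expects

    values_per_sample = int(np.prod(x.shape[1:]))
    if values_per_sample != timesteps * features:
        raise ValueError(
            f"cannot reshape samples of shape {x.shape[1:]} "
            f"to ({timesteps}, {features})"
        )
    # Keep the sample axis, flatten everything else into (timesteps, features).
    return x.reshape(x.shape[0], timesteps, features)


# Small stand-in for the real data: 4 samples shaped like the error message.
x_demo = np.zeros((4, 800, 20, 1), dtype=np.float32)
x_fixed = to_lstm_input(x_demo)
print(x_demo.shape, "->", x_fixed.shape)  # (4, 800, 20, 1) -> (4, 16000, 1)
```

With something like this, you would reshape both `x_train` and `x_test` before calling `model.fit`, and also confirm that `x_train.shape[0] == y_train.shape[0]` so every sample has a label.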