设置tensorflow时序模型输入层的形状

Setting the shape of tensorflow sequential model input layer

我正在尝试为多 class class化构建模型,但我不明白如何设置正确的输入形状。我有一个形状为 (5420, 212) 的训练集,这是我构建的模型:

model = models.Sequential()
model.add(layers.Dense(64, activation='relu', input_shape = (5420,)))
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(5, activation='softmax'))
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
history = model.fit(X_train, y_train, epochs=20, batch_size=512)

当我 运行 它时,我得到错误:

ValueError: Input 0 of layer sequential_9 is incompatible with the layer: expected axis -1 of input shape to have value 5420 but received input with shape (None, 212)

为什么?输入的值不对吗?

输入形状应等于输入X第二维的长度,而输出形状应等于输出Y第二维的长度(假设两者XY 是二维的,即它们没有更高的维度)。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.datasets import make_classification
from sklearn.preprocessing import OneHotEncoder
tf.random.set_seed(0)

# generate some data
X, y = make_classification(n_classes=5, n_samples=5420, n_features=212, n_informative=212, n_redundant=0, random_state=42)
print(X.shape, y.shape)
# (5420, 212) (5420,)

# one-hot encode the target
Y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1))
print(X.shape, Y.shape)
# (5420, 212) (5420, 5)

# extract the input and output shapes
input_shape = X.shape[1]   
output_shape = Y.shape[1]   
print(input_shape, output_shape)
# 212 5

# define the model
model = Sequential()
model.add(Dense(64, activation='relu', input_shape=(input_shape,)))
model.add(Dense(64, activation='relu'))
model.add(Dense(output_shape, activation='softmax'))

# compile the model
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

# fit the model
history = model.fit(X, Y, epochs=3, batch_size=512)
# Epoch 1/3
# 11/11 [==============================] - 0s 1ms/step - loss: 4.8206 - accuracy: 0.2208
# Epoch 2/3
# 11/11 [==============================] - 0s 1ms/step - loss: 2.8060 - accuracy: 0.3229
# Epoch 3/3
# 11/11 [==============================] - 0s 1ms/step - loss: 2.0705 - accuracy: 0.3989