如何在不使用 Lambda 层的情况下向数据点添加维度?

How do you add a dimension to the datapoints without using Lambda layer?

我正在尝试使用 Conv2D 层对 fashion_mnist 数据集进行分类,据我所知,使用以下代码可以轻松完成此操作:

import tensorflow as tf

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

train_images = train_images / 255.0
test_images = test_images / 255.0

model = tf.keras.Sequential([
    tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1)),
    tf.keras.layers.Input(shape=(28,28),batch_size=32),      
    tf.keras.layers.Conv2D(4,kernel_size=3),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation="softmax")
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(x=train_images, y=train_labels, validation_data=(test_images, test_labels), epochs=10)

但是,我被要求不使用 Lambda 层。所以,上面的解决方案是不正确的。

所以,我想知道如何在不使用 Lambda 层的情况下对 mnist_fashion 数据集进行分类?

更新: 当我使用以下代码添加维度时:

train_images = train_images / 255.0
train_images = tf.expand_dims(train_images,axis=0)

test_images = test_images / 255.0
test_images = tf.expand_dims(test_images,axis=0)

和运行它针对相同的模型,我得到以下错误:

ValueError: Data cardinality is ambiguous:
  x sizes: 1
  y sizes: 60000
Make sure all arrays contain the same number of samples.

有几个选项: 直接在 train_images 上使用 expand_dims 并更改输入形状或使用 Reshape 图层而不是 Lambda 图层或完全移除图层并更改 input_shape tf.keras.layers.Input(shape=(28,28, 1),batch_size=32)。取决于你想要什么。这是 axis=-1 中的 expand_dims 选项:

train_images = train_images / 255.0
train_images = tf.expand_dims(train_images, axis=-1)

test_images = test_images / 255.0
test_images = tf.expand_dims(test_images, axis=-1)

并将Input图层更改为tf.keras.layers.Input(shape=(28, 28, 1)

Q1 答案:因为 Conv2D 层需要 3D 输入(不包括批量大小)所以类似于:(rows, cols, channels)。您的数据具有 (samples, 28, 28) 的形状。如果您的频道出现在行和列之前,您可以在 axis=1 上使用 expand_dims,从而在使用 axis=-1 时得到 (samples, 1, 28, 28) 而不是 (samples, 28, 28, 1)。如果是前者,则必须将 Conv2D 层中的 data_format 参数设置为 channels_first。使用 axis=0 导致形状 (1 samples, 28, 28),这是不正确的,因为第一个维度应保留为批次维度。

Q2 答:我用了shape=(28, 28, 1)因为fashion mnist的图都是灰度图。也就是说,它们只有一个通道(我们已明确定义)。另一方面,RGB 图像有 3 个通道:红色通道、绿色通道和蓝色通道。