Tensorflow:不兼容的形状:[1,2] 与 [1,4,4,2048]

Tensorflow: Incompatible shapes: [1,2] vs. [1,4,4,2048]

我有以下张量流模型:

img_width, img_height = 120, 120

dg = DataGenerator('/mnt/e/Shared/Stfc/Images', target_size=(img_height, img_width), batch_size=1)

input_tensor = tf.keras.Input(shape=(img_width, img_height, 3))
base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_tensor=input_tensor)

model = base_model
optimizer = tf.keras.optimizers.RMSprop(0.001)

model.compile(loss='mse',
              optimizer=optimizer,
              metrics=['mae', 'mse'])

model.fit(dg)

为了缩小问题范围,我对此进行了一些简化,

当我 运行 这样做时,我收到以下错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError:  Incompatible shapes: [1,2] vs. [1,4,4,2048]
         [[node mean_squared_error/SquaredDifference (defined at /projects/tensorflow/stfcxy.py:130) ]] [Op:__inference_train_function_15679]

这个错误似乎总是出现在不同的输入图像上。我所有的图片都是一模一样的尺寸

我正在使用 tensorflow 2.4.1

我错过了什么?

ResNet50 模型输出一个形状为 (4,4,2048) 的张量,而您期望的形状为 (2,),因此您肯定必须通过应用来减小该张量的大小进一步致密层。这是一个简单的工作示例,但我建议使用具有更多层的深层网络。

import tensorflow as tf

img_width, img_height = 120, 120

input_tensor = tf.keras.Input(shape=(img_width, img_height, 3))
base_model = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_tensor=input_tensor)
x = tf.keras.layers.GlobalMaxPool2D()(base_model.output)
output = tf.keras.layers.Dense(2, activation='linear')(x)

model = tf.keras.Model(base_model.input, output)
optimizer = tf.keras.optimizers.RMSprop(0.001)

model.compile(loss='mse',
              optimizer=optimizer,
              metrics=['mae', 'mse'])

samples = 20
images = tf.random.normal((samples, 120, 120, 3))
x_y_coords = tf.random.normal((samples, 2))
model.fit(images, x_y_coords, batch_size=2, epochs=5)
Epoch 1/5
10/10 [==============================] - 20s 689ms/step - loss: 547.9037 - mae: 16.8050 - mse: 547.9037
Epoch 2/5
10/10 [==============================] - 7s 685ms/step - loss: 560.1724 - mae: 17.3702 - mse: 560.1724
Epoch 3/5
10/10 [==============================] - 7s 694ms/step - loss: 166.5985 - mae: 8.9817 - mse: 166.5985
Epoch 4/5
10/10 [==============================] - 7s 684ms/step - loss: 169.9773 - mae: 8.6677 - mse: 169.9773
Epoch 5/5
10/10 [==============================] - 7s 684ms/step - loss: 201.1059 - mae: 9.6540 - mse: 201.1059
<keras.callbacks.History at 0x7fcaae3e5890>