Error when fitting a model with data from ImageDataGenerator and tf.data.Dataset

I created a training set with ImageDataGenerator and tf.data.Dataset as follows:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)

ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers, 
                                        batch_size=32, 
                                        target_size=(224,224)), 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([32,224,224,3], [32,5])
)

print(ds.element_spec)

train_set = ds.prefetch(1)

Then I fit the model with train_set as follows:

base_model = tf.keras.applications.xception.Xception(weights='imagenet',
                                                     include_top=False)

avg_layer = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output_layer = tf.keras.layers.Dense(5, activation='softmax')(avg_layer)

model = tf.keras.Model(inputs=base_model.input, outputs=output_layer)

base_model.trainable=False

model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(),
              metrics='accuracy')

history_first = model.fit(train_set, epochs=5)

But I get the following error:

Epoch 1/5
Found 3670 images belonging to 5 classes.
    114/Unknown - 50s 408ms/step - loss: 0.7188 - accuracy: 0.7459
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-29-c453f1f904a2> in <module>()
     11               metrics='accuracy')
     12 
---> 13 history_first = model.fit(train_set, epochs=5)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53     ctx.ensure_initialized()
     54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:
     57     if name is not None:

InvalidArgumentError: Graph execution error:

2 root error(s) found.
  (0) INVALID_ARGUMENT:  TypeError: `generator` yielded an element of shape (22, 224, 224, 3) where an element of shape (32, 224, 224, 3) was expected.
Traceback (most recent call last):

  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/script_ops.py", line 271, in __call__
    ret = func(*args)

  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1048, in generator_py_func
    f"`generator` yielded an element of shape {ret_array.shape} "

TypeError: `generator` yielded an element of shape (22, 224, 224, 3) where an element of shape (32, 224, 224, 3) was expected.


     [[{{node PyFunc}}]]
     [[IteratorGetNext]]
     [[IteratorGetNext/_4]]
  (1) INVALID_ARGUMENT:  TypeError: `generator` yielded an element of shape (22, 224, 224, 3) where an element of shape (32, 224, 224, 3) was expected.
Traceback (most recent call last):

  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/script_ops.py", line 271, in __call__
    ret = func(*args)

  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/usr/local/lib/python3.7/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1048, in generator_py_func
    f"`generator` yielded an element of shape {ret_array.shape} "

TypeError: `generator` yielded an element of shape (22, 224, 224, 3) where an element of shape (32, 224, 224, 3) was expected.


     [[{{node PyFunc}}]]
     [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_160597]
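The failure can be reproduced in isolation. This is a sketch under stated assumptions, not the question's pipeline: `batched_generator` is a hypothetical stand-in for `flow_from_directory` that yields 11 examples in batches of 4 (so the last batch has 3), and it uses `output_signature`, which replaces `output_types`/`output_shapes` in TF 2.4 and later. Declaring a fixed batch size of 4 makes the short last batch fail the same shape check that appears in the traceback.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for flow_from_directory: 11 examples with batch_size=4
# gives batches of 4, 4 and 3, just as 3670 images give 114 x 32 plus 22.
def batched_generator(num_examples=11, batch_size=4):
    for start in range(0, num_examples, batch_size):
        n = min(batch_size, num_examples - start)
        yield np.zeros((n, 2), dtype=np.float32), np.zeros((n, 5), dtype=np.float32)

# A fixed batch dimension (4) mirrors output_shapes=([32, ...], [32, 5]).
fixed_ds = tf.data.Dataset.from_generator(
    batched_generator,
    output_signature=(
        tf.TensorSpec(shape=(4, 2), dtype=tf.float32),
        tf.TensorSpec(shape=(4, 5), dtype=tf.float32),
    ),
)

def count_batches(ds):
    """Iterate eagerly; return (batches read, error message or None)."""
    count = 0
    try:
        for _ in ds:
            count += 1
    except tf.errors.InvalidArgumentError as err:
        return count, err.message
    return count, None

# The full batches are read, then the 3-example batch raises the shape error.
num_ok, error = count_batches(fixed_ds)
print(num_ok)
print(error)
```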

I know that the last batch has only 22 examples left. How can I fix this error?
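The numbers line up with the log: `flow_from_directory` reported 3670 images, and with `batch_size=32` one pass over the directory yields 114 full batches plus one batch of 22. The 115th element is therefore the first one that cannot match the fixed shape `[32, 224, 224, 3]`, which is consistent with training reaching step 114 before failing. A quick check of the arithmetic:

```python
import math

num_images = 3670   # from "Found 3670 images belonging to 5 classes."
batch_size = 32

# Full batches and the size of the leftover batch in one pass.
full_batches, last_batch = divmod(num_images, batch_size)
# Total batches per pass, counting the short one.
batches_per_pass = math.ceil(num_images / batch_size)

print(full_batches, last_batch, batches_per_pass)  # 114 22 115
```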

One more question: how do I use the batching functionality of tf.data.Dataset? This code does not work:

train_set = train_set.batch(32, drop_remainder=True).prefetch(1)
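A likely reason the `.batch(...)` call does not help: each element produced by the generator is already a batch of 32, so `.batch(32)` stacks 32 such batches into a 5-D tensor instead of rebatching, and the generator's fixed-shape check still fires before any batching happens. One approach, sketched here with a hypothetical small generator (`batched_generator`, 11 examples in batches of 4) instead of the flowers data, is to declare the batch dimension as `None`, split the elements back into single examples with `unbatch()`, and then call `batch(..., drop_remainder=True)`:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for flow_from_directory: 11 examples, batch_size=4,
# so the generator yields already-batched elements of 4, 4 and 3 examples.
def batched_generator(num_examples=11, batch_size=4):
    for start in range(0, num_examples, batch_size):
        n = min(batch_size, num_examples - start)
        x = np.arange(start, start + n, dtype=np.float32).reshape(n, 1)
        y = np.zeros((n, 5), dtype=np.float32)
        yield x, y

# Leave the batch dimension unknown (None) so the short batch is accepted.
ds = tf.data.Dataset.from_generator(
    batched_generator,
    output_signature=(
        tf.TensorSpec(shape=(None, 1), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 5), dtype=tf.float32),
    ),
)

# Batching already-batched elements adds a second, unknown batch axis.
double_batched = ds.batch(4)
print(double_batched.element_spec[0].shape)

# unbatch() splits elements back into single examples; batch() regroups them.
rebatched = ds.unbatch().batch(4, drop_remainder=True)
batch_sizes = [int(x.shape[0]) for x, _ in rebatched]
print(batch_sizes)  # [4, 4]: the 3 leftover examples are dropped
```

Applied to the question, this would mean `output_shapes=([None, 224, 224, 3], [None, 5])` followed by `ds.unbatch().batch(32, drop_remainder=True)`. This is untested against the flowers data; note also that `flow_from_directory` keeps cycling over the images, so regrouped batches may straddle two passes and `steps_per_epoch` is still needed.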

Try defining a variable batch size with None and setting steps_per_epoch:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)

train_set = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers, 
                                        batch_size=32, 
                                        target_size=(224,224)), 
    output_types=(tf.float32, tf.float32), 
    output_shapes=([None, 224, 224, 3], [None, 5])
)

train_set = train_set.prefetch(1)

base_model = tf.keras.applications.xception.Xception(weights='imagenet',
                                                     include_top=False)

avg_layer = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output_layer = tf.keras.layers.Dense(5, activation='softmax')(avg_layer)

model = tf.keras.Model(inputs=base_model.input, outputs=output_layer)

base_model.trainable=False

model.compile(loss=tf.keras.losses.CategoricalCrossentropy(),
              optimizer=tf.keras.optimizers.Adam(), metrics=['categorical_accuracy'])

history_first = model.fit(train_set, steps_per_epoch=3670 // 32, epochs=5)

Epoch 1/5
Found 3670 images belonging to 5 classes.
114/114 [==============================] - 78s 593ms/step - loss: 0.7008 - categorical_accuracy: 0.7563
Epoch 2/5
114/114 [==============================] - 62s 535ms/step - loss: 0.4332 - categorical_accuracy: 0.8480
Epoch 3/5
114/114 [==============================] - 61s 534ms/step - loss: 0.3553 - categorical_accuracy: 0.8782
Epoch 4/5
114/114 [==============================] - 60s 527ms/step - loss: 0.3236 - categorical_accuracy: 0.8892
Epoch 5/5
114/114 [==============================] - 61s 535ms/step - loss: 0.3068 - categorical_accuracy: 0.8972
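Why this works, as far as I can tell: with `None` as the batch dimension, the 22-example batch passes the shape check, and because the iterator returned by `flow_from_directory` loops over the images indefinitely, `steps_per_epoch` is what tells Keras where an epoch ends (the first attempt shows `114/Unknown` because no epoch length was known). The pure-Python model below is a hypothetical simulation of that cycling iterator (`cycling_batch_sizes` is not a Keras function). It shows that with `3670 // 32 = 114` steps, each epoch sees 114 * 32 = 3648 images and the short batch of 22 is picked up at the start of the next epoch:

```python
import itertools

NUM_IMAGES = 3670
BATCH_SIZE = 32

def cycling_batch_sizes(num_images, batch_size):
    """Hypothetical model of a Keras directory iterator: it never stops,
    it starts a new pass over the images after the short last batch."""
    while True:
        for start in range(0, num_images, batch_size):
            yield min(batch_size, num_images - start)

steps_per_epoch = NUM_IMAGES // BATCH_SIZE  # 114, as in the answer
stream = cycling_batch_sizes(NUM_IMAGES, BATCH_SIZE)

# Each simulated epoch consumes exactly steps_per_epoch batches from the stream.
epochs = [list(itertools.islice(stream, steps_per_epoch)) for _ in range(2)]

# Last batch of epoch 1 is full; epoch 2 starts with the leftover batch.
print(epochs[0][-1], epochs[1][0])  # 32 22
```

Using `math.ceil(3670 / 32)` (115 steps) instead would, under this model, align each epoch with exactly one pass over the directory.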