Keras - ImageDataGenerator: How to get a batch of labels?

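My model requires two separate inputs: the images and the labels. But with ImageDataGenerator's flow_from_dataframe I can only get a full batch that contains both the images and the labels. What should I do?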

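The problem is that flow_from_dataframe only seems to accept a single column of the dataframe as x. You can wrap flow_from_dataframe in tf.data.Dataset.from_generator and then use tf.data.Dataset.map to pass your labels in as an input as well. Here is an example using flow_from_directory:
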
import tensorflow as tf

BATCH_SIZE = 32

# Download and extract the flower photos dataset (one sub-folder per class).
flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)

# Wrap the Keras iterator in a tf.data.Dataset.
# Each element is an (images, one-hot labels) batch.
ds = tf.data.Dataset.from_generator(
    lambda: img_gen.flow_from_directory(flowers, batch_size=BATCH_SIZE, shuffle=True),
    output_types=(tf.float32, tf.float32))

# Re-pack each batch so the labels are both a model input and the target.
ds = ds.map(lambda x, y: ((x, y), y))

for x, y in ds.take(1):
  input1, input2 = x
  print(input1.shape, input2.shape)
Output:
Found 3670 images belonging to 5 classes.
(32, 256, 256, 3) (32, 5)
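
The same re-packing can also be sketched without tf.data, by wrapping the iterator in a plain Python generator. This is only a minimal illustration of the idea, not part of the approach above: `fake_batches` is a hypothetical stand-in for the iterator returned by flow_from_dataframe or flow_from_directory (it just yields random arrays with matching shapes), and `with_label_input` is an assumed helper name.

```python
import numpy as np


def fake_batches(num_batches=2, batch_size=32, image_size=(256, 256),
                 num_classes=5, seed=0):
    """Hypothetical stand-in for an ImageDataGenerator iterator.

    Yields (images, one_hot_labels) tuples of random data.
    """
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        images = rng.random((batch_size, *image_size, 3), dtype=np.float32)
        class_ids = rng.integers(0, num_classes, size=batch_size)
        # Turn integer class ids into one-hot rows.
        labels = np.eye(num_classes, dtype=np.float32)[class_ids]
        yield images, labels


def with_label_input(batches):
    """Re-pack each (x, y) batch as ((x, y), y): labels become a second input."""
    for x, y in batches:
        yield (x, y), y


if __name__ == "__main__":
    for (input1, input2), target in with_label_input(fake_batches()):
        print(input1.shape, input2.shape, target.shape)
```

If a wrapper like this is put around a real Keras iterator, keep in mind that those iterators loop indefinitely, so the number of steps per epoch would typically have to be given explicitly when training.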

Alternatively, you can use tf.keras.utils.image_dataset_from_directory:

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

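# Re-pack each batch so the labels are both a model input and the target.
ds = ds.map(lambda x, y: ((x, y), y))

# With the default label_mode='int', the labels are integer class indices.
for x, y in ds.take(1):
  input1, input2 = x
  print(input1.shape, input2.shape)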