如何为 tensorflow.keras.preprocessing.text_dataset_from_directory 使用多个输入

How to use multiple inputs for tensorflow.keras.preprocessing.text_dataset_from_directory

所以我正在训练一个 CNN,它接收两个输入图像和 returns 单个值作为 GPU 上的输出。由于我有很多图像,为了将数据分批提供,我使用 tf.keras.preprocessing.text_dataset_from_directory 创建一个针对 GPU 优化的 tf.Dataset 对象。

所以基本上我的输入目录是

Class_1/
  Class1_1/
     image1.png
     image2.png
  Class1_2/
     image3.png
     image4.png
...
Class_2/
  Class2_1/
     image1.png
     image2.png
  Class2_2/
     image3.png
     image4.png

默认功能仅适用于以下结构

Class_1/
      image1.png
      image2.png
      image3.png
      image4.png
    ...
Class_2/
      image1.png
      image2.png
      image3.png
      image4.png

如有任何帮助,我们将不胜感激。

我猜你的意思是 image_dataset_from_directory 因为你加载的是图像而不是文本数据。无论哪种方式,您都无法通过这些辅助函数生成具有多个输入的批次,您可以看到 from the documentation 定义了 return 形状:

A tf.data.Dataset object.

  • If label_mode is None, it yields float32 tensors of shape (batch_size, image_size[0], image_size[1], num_channels), encoding images (see below for rules regarding num_channels).
  • Otherwise, it yields a tuple (images, labels), where images has shape (batch_size, image_size[0], image_size[1], num_channels), and labels follows the format described below.

您需要编写自己的自定义生成器函数,该函数会产生从您的数据目录加载的多个输入,然后使用您的自定义生成器调用 fit 并将 kwarg validation_data 传递给一个单独的生成器生成验证数据。 (注意:在某些旧版本的 Keras 中,您可能需要 fit_generator 而不是 fit)。

下面是一些辅助函数的模块示例,这些辅助函数可以从某些目录中读取图像并在训练中将它们呈现为多图像输入。

def _generate_batch(training):
    in1s, in2s, labels = [], [], []
    batch_tuples = _sample_batch_of_paths(training)
    for input1_path, input2_path in batch_tuples:
        # skip any exception so that image GPU batch loading isn't
        # disrupted and any faulty image is just skipped.
        try:
            in1_tmp = _load_image(
                os.path.join(INPUT1_PATH_PREFIX, input1_path),
            )
            in2_tmp = _load_image(
                os.path.join(INPUT2_PATH_PREFIX, input2_path),
            )
        except Exception as exc:
            print("Unhandled exception during image batch load. Skipping...")
            print(str(exc))
            continue
        # if no exception, both images loaded so both are added to batch.
        in1s.append(in1_tmp)
        in2s.append(in2_tmp)
        # Whatever your custom logic is to determine the label for the pair.
        labels.append(
            _label_calculation_helper(input1_path, input2_path)
        )
    in1s, in2s = map(skimage.io.concatenate_images, [in1s, in2s])
    # could also add a singleton channel dimension for grayscale images.
    # in1s = in1s[:, :, :, None]
    return [in1s, in2s], labels


def _make_generator(training=True):
    while True:
        yield _generate_batch(training)


def make_generators():
    return _make_generator(training=True), _make_generator(training=False)

助手 _load_image 可能是这样的:

def _load_image(path, is_gray=False):
    tmp = skimage.io.imread(path)
    if is_gray:
        tmp = skimage.util.img_as_float(skimage.color.rgb2gray(tmp))
    else:
        tmp = skimage.util.img_as_float(skimage.color.gray2rgb(tmp))
        if tmp.shape[-1] == 4:
            tmp = skimage.color.rgba2rgb(tmp)
    # Do other stuff here - resizing, clipping, etc.
    return tmp

从磁盘上列出的一组路径中对批次进行采样的辅助函数可能如下所示:

@lru_cache(1)
def _load_and_split_input_paths():
    training_in1s, testing_in1s = train_test_split(
        os.listdir(INPUT1_PATH_PREFIX),
        test_size=TEST_SIZE,
        random_state=RANDOM_SEED
    )
    training_in2s, testing_in2s = train_test_split(
        os.listdir(INPUT2_PATH_PREFIX),
        test_size=TEST_SIZE,
        random_state=RANDOM_SEED
    )
    return training_in1s, testing_in1s, training_in2s, testing_in2s


def _sample_batch_of_paths(training):
    training_in1s, testing_in1s, training_in2s, testing_in2s = _load_and_split_input_paths()
    if training:
        return list(zip(
            random.sample(training_in1s, BATCH_SIZE),
            random.sample(training_in2s, BATCH_SIZE)
        ))
    else:
        return list(zip(
            random.sample(testing_in1s, BATCH_SIZE),
            random.sample(testing_in2s, BATCH_SIZE)
        ))

这将从一些“输入 1”目录中随机抽取图像,并将它们与“输入 2”目录中的随机样本配对。显然,在您的用例中,您需要更改此设置,以便根据定义其配对和标签的文件结构确定性地提取数据。

终于要用这个了,调用训练代码如:

training_generator, testing_generator = make_generators()
try:
    some_compiled_model.fit(
        training_generator,
        epochs=EPOCHS,
        validation_data=testing_generator,
        callbacks=[...],
        verbose=VERBOSE,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_steps=VALIDATION_STEPS,
    )
except KeyboardInterrupt:
    pass