使用 Keras API,如何批量导入给定批次中每个 ID 恰好 K 个实例的图像?

Using Keras APIs, how can I import images in batches with exactly K instances of each ID in a given batch?

我正在尝试实现批量硬三元组损失,如 https://arxiv.org/pdf/2004.06271.pdf 的第 3.2 节所示。

我需要导入我的图像,以便每个批次在特定批次中每个 ID 恰好有 K 个实例。因此,每批必须是K的倍数。

我的图像目录太大,无法放入内存,因此我正在使用 ImageDataGenerator.flow_from_directory() 导入图像,但我看不到此函数的任何参数来实现我需要的功能。

如何使用 Keras 实现此批处理行为?

从 Tensorflow 2.4 开始,我没有看到使用 ImageDataGenerator.

执行此操作的标准方法

所以我认为你需要在tensorflow.keras.utils.Sequenceclass的基础上自己写一个,所以你可以自由定义批处理内容。

参考文献:
https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence
https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

您可以尝试以受控方式将多个数据流合并在一起。

假设您有 tf.data.Dataset 的 K 个实例(无论您如何实例化它们)负责提供特定 ID 的训练实例,您可以将它们连接起来以在小批量中均匀分布:

ds1 = ...  # Training instances with ID == 1
ds2 = ...  # Training instances with ID == 2
...
dsK = ... # Training instances with ID == K



train_dataset = tf.data.Dataset.zip((ds1, ds2, ..., dsK)).flat_map(concat_datasets).batch(batch_size=N * K)

其中 concat_datasets 是合并函数:

def concat_datasets(*datasets):
    ds = tf.data.Dataset.from_tensors(datasets[0])
    for i in range(1, len(datasets)):
        ds = ds.concatenate(tf.data.Dataset.from_tensors(datasets[i]))
    return ds