加载 tensorflow 图像并创建补丁

Load tensorflow images and create patches

我正在使用 image_dataset_from_directory to load a very large RGB imagery dataset from disk into a Dataset。例如,

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    seed=1,
    subset='training',
    validation_split=0.1)

数据集有 100000 张图像,分为大小为 32 的批次,产生 tf.data.Dataset 规格 (batch=32, width=256, height=256, channels=3)

我想从图像中提取补丁以创建新的 tf.data.Dataset 图像空间尺寸,例如 64x64。

因此,我想创建一个新的数据集,其中包含 400000 个补丁,仍然以 32 个为一组,tf.data.Dataset 规格为 (batch=32, width=64, height=64, channels=3)

我查看了 window method and the extract_patches 函数,但文档中并不清楚如何使用它们来创建新的数据集,我需要开始对补丁进行训练。 window 似乎适用于一维张量,而 extract_patches 似乎适用于数组而不适用于数据集。

有什么关于如何完成这个的建议吗?

更新:

只是为了澄清我的需求。我试图避免在磁盘上手动创建补丁。第一,这在磁盘方面是站不住脚的。二、补丁大小不固定。实验将在几个补丁大小上进行。所以,我不想在磁盘上手动创建补丁,也不想在内存中手动加载图像并执行补丁。我希望让 tensorflow 将补丁创建作为管道工作流的一部分来处理,以最大限度地减少磁盘和内存使用。

我相信您可以使用 python class 生成器。如果需要,您可以将此生成器传递给 model.fit 函数。我实际上用过一次标签预处理。

我编写了以下数据集生成器,它从您的数据集中加载一个批次,根据 tile_shape 参数将批次中的图像拆分为多个图像。如果有足够的图像,则返回下一批。

在示例中,为了简化,我使用了一个简单的数据集from_tensor_slices。当然,你也可以换成你的。

import tensorflow as tf

class TileDatasetGenerator:
    
    def __init__(self, dataset, batch_size, tile_shape):
        self.dataset_iterator = iter(dataset)
        self.batch_size = batch_size
        self.tile_shape = tile_shape
        self.image_queue = None
    
    def __iter__(self):
        return self
    
    def __next__(self):
        if self._has_queued_enough_for_batch():
            return self._dequeue_batch()
        
        batch = next(self.dataset_iterator)
        self._split_images(batch)    
        return self.__next__()
            
    def _has_queued_enough_for_batch(self):
        return self.image_queue is not None and tf.shape(self.image_queue)[0] >= self.batch_size
    
    def _dequeue_batch(self):
        batch, remainder = tf.split(self.image_queue, [self.batch_size, -1], axis=0)
        self.image_queue = remainder
        return batch
        
    def _split_images(self, batch):
        batch_shape = tf.shape(batch)
        batch_splitted = tf.reshape(batch, shape=[-1, self.tile_shape[0], self.tile_shape[1], batch_shape[-1]])
        if self.image_queue is None:
            self.image_queue = batch_splitted
        else:
            self.image_queue = tf.concat([self.image_queue, batch_splitted], axis=0)
            


dataset = tf.data.Dataset.from_tensor_slices(tf.ones(shape=[128, 64, 64, 3]))
dataset.batch(32)
generator = TileDatasetGenerator(dataset, batch_size = 16, tile_shape = [32,32])

for batch in generator:
    tf.print(tf.shape(batch))

编辑: 如果需要,可以将生成器转换为 tf.data.Dataset,但它需要向生成器添加一个 __call__ 函数,返回一个迭代器(在本例中为 self)。

new_dataset = tf.data.Dataset.from_generator(generator, output_types=(tf.int64))

您要找的是tf.image.extract_patches。这是一个例子:

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

data = tfds.load('mnist', split='test', as_supervised=True)

get_patches = lambda x, y: (tf.reshape(
    tf.image.extract_patches(
        images=tf.expand_dims(x, 0),
        sizes=[1, 14, 14, 1],
        strides=[1, 14, 14, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'), (4, 14, 14, 1)), y)

data = data.map(get_patches)

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(data))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()