加载 tensorflow 图像并创建补丁
Load tensorflow images and create patches
我正在使用 image_dataset_from_directory to load a very large RGB imagery dataset from disk into a Dataset。例如,
dataset = tf.keras.preprocessing.image_dataset_from_directory(
<directory>,
label_mode=None,
seed=1,
subset='training',
validation_split=0.1)
数据集有 100000 张图像,分为大小为 32 的批次,产生 tf.data.Dataset
规格 (batch=32, width=256, height=256, channels=3)
我想从图像中提取补丁以创建新的 tf.data.Dataset
图像空间尺寸,例如 64x64。
因此,我想创建一个新的数据集,其中包含 400000 个补丁,仍然以 32 个为一组,tf.data.Dataset
规格为 (batch=32, width=64, height=64, channels=3)
我查看了 window method and the extract_patches 函数,但文档中并不清楚如何使用它们来创建新的数据集,我需要开始对补丁进行训练。 window
似乎适用于一维张量,而 extract_patches
似乎适用于数组而不适用于数据集。
有什么关于如何完成这个的建议吗?
更新:
只是为了澄清我的需求。我试图避免在磁盘上手动创建补丁。第一,这在磁盘方面是站不住脚的。二、补丁大小不固定。实验将在几个补丁大小上进行。所以,我不想在磁盘上手动创建补丁,也不想在内存中手动加载图像并执行补丁。我希望让 tensorflow 将补丁创建作为管道工作流的一部分来处理,以最大限度地减少磁盘和内存使用。
我相信您可以使用 python class 生成器。如果需要,您可以将此生成器传递给 model.fit
函数。我实际上用过一次标签预处理。
我编写了以下数据集生成器,它从您的数据集中加载一个批次,根据 tile_shape
参数将批次中的图像拆分为多个图像。如果有足够的图像,则返回下一批。
在示例中,为了简化,我使用了一个简单的数据集from_tensor_slices。当然,你也可以换成你的。
import tensorflow as tf
class TileDatasetGenerator:
def __init__(self, dataset, batch_size, tile_shape):
self.dataset_iterator = iter(dataset)
self.batch_size = batch_size
self.tile_shape = tile_shape
self.image_queue = None
def __iter__(self):
return self
def __next__(self):
if self._has_queued_enough_for_batch():
return self._dequeue_batch()
batch = next(self.dataset_iterator)
self._split_images(batch)
return self.__next__()
def _has_queued_enough_for_batch(self):
return self.image_queue is not None and tf.shape(self.image_queue)[0] >= self.batch_size
def _dequeue_batch(self):
batch, remainder = tf.split(self.image_queue, [self.batch_size, -1], axis=0)
self.image_queue = remainder
return batch
def _split_images(self, batch):
batch_shape = tf.shape(batch)
batch_splitted = tf.reshape(batch, shape=[-1, self.tile_shape[0], self.tile_shape[1], batch_shape[-1]])
if self.image_queue is None:
self.image_queue = batch_splitted
else:
self.image_queue = tf.concat([self.image_queue, batch_splitted], axis=0)
dataset = tf.data.Dataset.from_tensor_slices(tf.ones(shape=[128, 64, 64, 3]))
dataset.batch(32)
generator = TileDatasetGenerator(dataset, batch_size = 16, tile_shape = [32,32])
for batch in generator:
tf.print(tf.shape(batch))
编辑:
如果需要,可以将生成器转换为 tf.data.Dataset
,但它需要向生成器添加一个 __call__ 函数,返回一个迭代器(在本例中为 self)。
new_dataset = tf.data.Dataset.from_generator(generator, output_types=(tf.int64))
您要找的是tf.image.extract_patches
。这是一个例子:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
data = tfds.load('mnist', split='test', as_supervised=True)
get_patches = lambda x, y: (tf.reshape(
tf.image.extract_patches(
images=tf.expand_dims(x, 0),
sizes=[1, 14, 14, 1],
strides=[1, 14, 14, 1],
rates=[1, 1, 1, 1],
padding='VALID'), (4, 14, 14, 1)), y)
data = data.map(get_patches)
fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(data))
for index, image in enumerate(images):
ax = plt.subplot(2, 2, index + 1)
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(image)
plt.show()
我正在使用 image_dataset_from_directory to load a very large RGB imagery dataset from disk into a Dataset。例如,
dataset = tf.keras.preprocessing.image_dataset_from_directory(
<directory>,
label_mode=None,
seed=1,
subset='training',
validation_split=0.1)
数据集有 100000 张图像,分为大小为 32 的批次,产生 tf.data.Dataset
规格 (batch=32, width=256, height=256, channels=3)
我想从图像中提取补丁以创建新的 tf.data.Dataset
图像空间尺寸,例如 64x64。
因此,我想创建一个新的数据集,其中包含 400000 个补丁,仍然以 32 个为一组,tf.data.Dataset
规格为 (batch=32, width=64, height=64, channels=3)
我查看了 window method and the extract_patches 函数,但文档中并不清楚如何使用它们来创建新的数据集,我需要开始对补丁进行训练。 window
似乎适用于一维张量,而 extract_patches
似乎适用于数组而不适用于数据集。
有什么关于如何完成这个的建议吗?
更新:
只是为了澄清我的需求。我试图避免在磁盘上手动创建补丁。第一,这在磁盘方面是站不住脚的。二、补丁大小不固定。实验将在几个补丁大小上进行。所以,我不想在磁盘上手动创建补丁,也不想在内存中手动加载图像并执行补丁。我希望让 tensorflow 将补丁创建作为管道工作流的一部分来处理,以最大限度地减少磁盘和内存使用。
我相信您可以使用 python class 生成器。如果需要,您可以将此生成器传递给 model.fit
函数。我实际上用过一次标签预处理。
我编写了以下数据集生成器,它从您的数据集中加载一个批次,根据 tile_shape
参数将批次中的图像拆分为多个图像。如果有足够的图像,则返回下一批。
在示例中,为了简化,我使用了一个简单的数据集from_tensor_slices。当然,你也可以换成你的。
import tensorflow as tf
class TileDatasetGenerator:
def __init__(self, dataset, batch_size, tile_shape):
self.dataset_iterator = iter(dataset)
self.batch_size = batch_size
self.tile_shape = tile_shape
self.image_queue = None
def __iter__(self):
return self
def __next__(self):
if self._has_queued_enough_for_batch():
return self._dequeue_batch()
batch = next(self.dataset_iterator)
self._split_images(batch)
return self.__next__()
def _has_queued_enough_for_batch(self):
return self.image_queue is not None and tf.shape(self.image_queue)[0] >= self.batch_size
def _dequeue_batch(self):
batch, remainder = tf.split(self.image_queue, [self.batch_size, -1], axis=0)
self.image_queue = remainder
return batch
def _split_images(self, batch):
batch_shape = tf.shape(batch)
batch_splitted = tf.reshape(batch, shape=[-1, self.tile_shape[0], self.tile_shape[1], batch_shape[-1]])
if self.image_queue is None:
self.image_queue = batch_splitted
else:
self.image_queue = tf.concat([self.image_queue, batch_splitted], axis=0)
dataset = tf.data.Dataset.from_tensor_slices(tf.ones(shape=[128, 64, 64, 3]))
dataset.batch(32)
generator = TileDatasetGenerator(dataset, batch_size = 16, tile_shape = [32,32])
for batch in generator:
tf.print(tf.shape(batch))
编辑:
如果需要,可以将生成器转换为 tf.data.Dataset
,但它需要向生成器添加一个 __call__ 函数,返回一个迭代器(在本例中为 self)。
new_dataset = tf.data.Dataset.from_generator(generator, output_types=(tf.int64))
您要找的是tf.image.extract_patches
。这是一个例子:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
data = tfds.load('mnist', split='test', as_supervised=True)
get_patches = lambda x, y: (tf.reshape(
tf.image.extract_patches(
images=tf.expand_dims(x, 0),
sizes=[1, 14, 14, 1],
strides=[1, 14, 14, 1],
rates=[1, 1, 1, 1],
padding='VALID'), (4, 14, 14, 1)), y)
data = data.map(get_patches)
fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(data))
for index, image in enumerate(images):
ax = plt.subplot(2, 2, index + 1)
ax.set_xticks([])
ax.set_yticks([])
ax.imshow(image)
plt.show()