回读自定义数据集 TFRecords
Reading back a custom dataset TFRecords
我正在尝试在 TFRecords
中为 CycleGAN 模型创建自定义数据集。该模型需要一种不可用的新型数据集,因此我需要创建一个。我有一些 256x256 的 JPG 图片。在 this link 之后,我为我的图像创建了 TFrecords 文件,代码如下:
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# images input
def convert_to(images, output_directory, name):
num_examples = images.shape[0]
rows = images.shape[1]
cols = images.shape[2]
depth = 1
filename = os.path.join(output_directory, name + '.tfrecords')
print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image = images[index]
image_raw = images[index].tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
def read_image(file_name, images_path):
image = skimage.io.imread(images_path + file_name)
return image
def get_name(img_name):
remove_ext = img_name.split(".")[0]
name = remove_ext.split("_")
return name[0]
images_path = "data/train/"
image_list = os.listdir(images_path)
images = []
for img_name in tqdm(image_list):
tfrec_name = get_name(img_name)
print(tfrec_name)
img_data = read_image(img_name, images_path)
convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
写入 TFRecords 后,我使用下面的代码读取并解码它
PHOTO_FILENAMES = tf.io.gfile.glob(str('data/cat_image_tfrecords/*.tfrecords'))
IMAGE_SIZE = [256, 256]
def decode_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = (tf.cast(image, tf.float32) / 127.5) - 1
image = tf.reshape(image, [*IMAGE_SIZE, 3])
return image
def read_tfrecord(example):
tfrecord_format = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
example = tf.io.parse_single_example(example, tfrecord_format)
image = decode_image(example['image_raw'])
return image
def load_dataset(filenames, labeled=True, ordered=False):
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
return dataset
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=False).batch(1)
example_photo = next(iter(photo_ds))
解码失败,因为我在最后一行收到以下错误
InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got unknown format starting with ']LBhXKeQFVC4S=/1'
[[{{node DecodeJpeg}}]]
很明显,我在 convert_to
函数中编写 TFRecord 的方式与我在 read_tfrecord
函数中读取它的方式不匹配。但我不确定如何修复它。有什么建议吗?
编辑
@sebastian-sz 方案解决问题。我试着显示一张如下图
import matplotlib.pyplot as plt
plt.subplot(121)
plt.title('Photo')
plt.imshow(example_photo[0])
它显示图像,但我看到图像的 color/light 比原始图像暗得多。虽然不确定发生了什么。附截图。原图在底部。
我更新了 read_tfrecord
函数,如下所示,请参阅注释掉的行并解决了这个特定错误。
def read_tfrecord(example):
tfrecord_format = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
example = tf.io.parse_single_example(example, tfrecord_format)
#image = decode_image(example['image_raw'])
image = tf.io.decode_raw(example['image_raw'], tf.float32)
return image
您的代码中存在一些问题:
参数问题:
问题出在一个函数 convert_to
中,更详细地说,该函数需要一个图像列表:
(...)
image = images[index]
(...)
但是,您传递的是单张图片
convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
因此,后来 image
的形状(例如)224, 3
是无效的图像形状。
要解决此问题,请更改 convert_to
以接受单个图像。
序列化问题
Skimage .tobytes
似乎不兼容。考虑使用 tf.io.encode_jpeg(image).numpy()
获取图像字节。
完整代码
我能够使用以下代码保存和读取示例图像:
# Saving
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# images input
def convert_to(image, output_directory, name):
rows = image.shape[0]
cols = image.shape[1]
depth = 1
filename = os.path.join(output_directory, name + '.tfrecords')
print('Writing', filename)
writer = tf.compat.v1.python_io.TFRecordWriter(filename)
print(image.shape)
image_raw = tf.io.encode_jpeg(image).numpy()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
def read_image(file_name, images_path):
image = skimage.io.imread(images_path + file_name)
return image
def get_name(img_name):
remove_ext = img_name.split(".")[0]
name = remove_ext.split("_")
return name[0]
images_path = "data/train/"
image_list = os.listdir(images_path)
for img_name in tqdm(image_list):
tfrec_name = get_name(img_name)
print(tfrec_name)
img_data = read_image(img_name, images_path)
convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
# Loading:
PHOTO_FILENAMES = tf.io.gfile.glob(str('data/cat_image_tfrecords/*.tfrecords'))
IMAGE_SIZE = [256, 256]
def decode_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = (tf.cast(image, tf.float32) / 127.5) - 1
# Changed this from reshape
# Consider reshape if all your images have the same shape
image = tf.image.resize(image, IMAGE_SIZE)
return image
def read_tfrecord(example):
tfrecord_format = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
example = tf.io.parse_single_example(example, tfrecord_format)
image = decode_image(example['image_raw'])
return image
def load_dataset(filenames, labeled=True, ordered=False):
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
return dataset
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=False).batch(1)
example_photo = next(iter(photo_ds))
我正在尝试在 TFRecords
中为 CycleGAN 模型创建自定义数据集。该模型需要一种不可用的新型数据集,因此我需要创建一个。我有一些 256x256 的 JPG 图片。在 this link 之后,我为我的图像创建了 TFrecords 文件,代码如下:
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# images input
def convert_to(images, output_directory, name):
num_examples = images.shape[0]
rows = images.shape[1]
cols = images.shape[2]
depth = 1
filename = os.path.join(output_directory, name + '.tfrecords')
print('Writing', filename)
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
image = images[index]
image_raw = images[index].tobytes()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
def read_image(file_name, images_path):
image = skimage.io.imread(images_path + file_name)
return image
def get_name(img_name):
remove_ext = img_name.split(".")[0]
name = remove_ext.split("_")
return name[0]
images_path = "data/train/"
image_list = os.listdir(images_path)
images = []
for img_name in tqdm(image_list):
tfrec_name = get_name(img_name)
print(tfrec_name)
img_data = read_image(img_name, images_path)
convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
写入 TFRecords 后,我使用下面的代码读取并解码它
PHOTO_FILENAMES = tf.io.gfile.glob(str('data/cat_image_tfrecords/*.tfrecords'))
IMAGE_SIZE = [256, 256]
def decode_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = (tf.cast(image, tf.float32) / 127.5) - 1
image = tf.reshape(image, [*IMAGE_SIZE, 3])
return image
def read_tfrecord(example):
tfrecord_format = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
example = tf.io.parse_single_example(example, tfrecord_format)
image = decode_image(example['image_raw'])
return image
def load_dataset(filenames, labeled=True, ordered=False):
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
return dataset
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=False).batch(1)
example_photo = next(iter(photo_ds))
解码失败,因为我在最后一行收到以下错误
InvalidArgumentError: Expected image (JPEG, PNG, or GIF), got unknown format starting with ']LBhXKeQFVC4S=/1'
[[{{node DecodeJpeg}}]]
很明显,我在 convert_to
函数中编写 TFRecord 的方式与我在 read_tfrecord
函数中读取它的方式不匹配。但我不确定如何修复它。有什么建议吗?
编辑
@sebastian-sz 方案解决问题。我试着显示一张如下图
import matplotlib.pyplot as plt
plt.subplot(121)
plt.title('Photo')
plt.imshow(example_photo[0])
它显示图像,但我看到图像的 color/light 比原始图像暗得多。虽然不确定发生了什么。附截图。原图在底部。
我更新了 read_tfrecord
函数,如下所示,请参阅注释掉的行并解决了这个特定错误。
def read_tfrecord(example):
tfrecord_format = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
example = tf.io.parse_single_example(example, tfrecord_format)
#image = decode_image(example['image_raw'])
image = tf.io.decode_raw(example['image_raw'], tf.float32)
return image
您的代码中存在一些问题:
参数问题:
问题出在一个函数 convert_to
中,更详细地说,该函数需要一个图像列表:
(...)
image = images[index]
(...)
但是,您传递的是单张图片
convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
因此,后来 image
的形状(例如)224, 3
是无效的图像形状。
要解决此问题,请更改 convert_to
以接受单个图像。
序列化问题
Skimage .tobytes
似乎不兼容。考虑使用 tf.io.encode_jpeg(image).numpy()
获取图像字节。
完整代码
我能够使用以下代码保存和读取示例图像:
# Saving
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# images input
def convert_to(image, output_directory, name):
rows = image.shape[0]
cols = image.shape[1]
depth = 1
filename = os.path.join(output_directory, name + '.tfrecords')
print('Writing', filename)
writer = tf.compat.v1.python_io.TFRecordWriter(filename)
print(image.shape)
image_raw = tf.io.encode_jpeg(image).numpy()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(rows),
'width': _int64_feature(cols),
'depth': _int64_feature(depth),
'image_raw': _bytes_feature(image_raw)}))
writer.write(example.SerializeToString())
def read_image(file_name, images_path):
image = skimage.io.imread(images_path + file_name)
return image
def get_name(img_name):
remove_ext = img_name.split(".")[0]
name = remove_ext.split("_")
return name[0]
images_path = "data/train/"
image_list = os.listdir(images_path)
for img_name in tqdm(image_list):
tfrec_name = get_name(img_name)
print(tfrec_name)
img_data = read_image(img_name, images_path)
convert_to(img_data, "data/cat_image_tfrecords", tfrec_name+"_cat_image")
# Loading:
PHOTO_FILENAMES = tf.io.gfile.glob(str('data/cat_image_tfrecords/*.tfrecords'))
IMAGE_SIZE = [256, 256]
def decode_image(image):
image = tf.image.decode_jpeg(image, channels=3)
image = (tf.cast(image, tf.float32) / 127.5) - 1
# Changed this from reshape
# Consider reshape if all your images have the same shape
image = tf.image.resize(image, IMAGE_SIZE)
return image
def read_tfrecord(example):
tfrecord_format = {
'height': tf.io.FixedLenFeature([], tf.int64),
'width': tf.io.FixedLenFeature([], tf.int64),
'depth': tf.io.FixedLenFeature([], tf.int64),
'image_raw': tf.io.FixedLenFeature([], tf.string),
}
example = tf.io.parse_single_example(example, tfrecord_format)
image = decode_image(example['image_raw'])
return image
def load_dataset(filenames, labeled=True, ordered=False):
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
return dataset
photo_ds = load_dataset(PHOTO_FILENAMES, labeled=False).batch(1)
example_photo = next(iter(photo_ds))