Tensorflow tfrecord 未被正确读取
Tensorflow tfrecord not being read correctly
我正在尝试使用 Tensorflow 在我自己的分割数据集上训练 CNN。根据我的研究,tfRecords 似乎是最好的选择。我已经弄清楚如何写入和读取 tfRecord 数据库,但绝对没有我尝试在 Tensorflow 图中成功读取它。这是从我的数据库中成功重建图像和基本事实的片段:
data_path = 'Training/train.tfrecords' # address to save the hdf5 file
record_iterator = tf.python_io.tf_record_iterator(path=data_path)
reconstructed_images = []
reconstructed_groundtruths = []
count = 0
for string_record in record_iterator:
example = tf.train.Example()
example.ParseFromString(string_record)
height = int(example.features.feature['height']
.int64_list
.value[0])
width = int(example.features.feature['width']
.int64_list
.value[0])
gt_string = (example.features.feature['train/groundTruth']
.bytes_list
.value[0])
image_string = (example.features.feature['train/image']
.bytes_list
.value[0])
img_1d = np.fromstring(image_string, dtype=np.uint8)
reconstructed_img = img_1d.reshape((height, width))
gt_1d = np.fromstring(gt_string, dtype=np.uint8)
reconstructed_gt = gt_1d.reshape((height, width))
reconstructed_images.append(reconstructed_img)
reconstructed_groundtruths.append(reconstructed_gt)
count += 1
这段代码成功地为我的数据库中的图像和地面实况标签提供了一个 numpy 数组列表。现在,为了尝试实际训练一些东西,我正在使用你可以找到的 MNIST 示例 here。
我已将解码函数替换为以下内容:
def decode(serialized_example):
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'height': tf.FixedLenFeature([1],tf.int64),
'width': tf.FixedLenFeature([1],tf.int64),
'train/image': tf.FixedLenFeature([], tf.string),
'train/groundTruth': tf.FixedLenFeature([], tf.string),
})
height = tf.cast(features['height'], tf.int64)
width = tf.cast(features['width'], tf.int64)
image = tf.decode_raw(features['train/image'], tf.uint8)
image.set_shape((height,width))
gt = tf.decode_raw(features['train/groundTruth'], tf.uint8)
gt.set_shape((height,width))
return image, gt
当我 运行 它时,有多个问题表明代码无法读取数据库。如上所述,我将在解析 height
的行上收到错误,其中指出
int() argument must be a string, a bytes-like object or a number, not
'Tensor'
如果我暂时将 height
和 width
设置为文字,我会在图像解析行上收到一条错误消息
Shapes (?,) and (512, 512) are not compatible
很明显,这意味着图像没有从数据库中正确读取,但我完全不知道为什么或如何修复它。有人可以告诉我我做错了什么吗?
我很幸运地找到了解决方案。显然,
image.set_shape((height,width))
应该是
image = tf.reshape(image,(height,width,1))
和 gt 类似。我不知道为什么我正在关注的 Tensorflow 教程使用 set_shape...我只能猜测它适用于 1d 但不适用于 2d 或更多?我现在可以看到它也不是张量函数,所以它不能使用像高度这样的图形相关变量,但这并不能解释为什么当我用全局替换 (height,width) 时它不起作用常数。如果有人知道,将不胜感激。
我正在尝试使用 Tensorflow 在我自己的分割数据集上训练 CNN。根据我的研究,tfRecords 似乎是最好的选择。我已经弄清楚如何写入和读取 tfRecord 数据库,但绝对没有我尝试在 Tensorflow 图中成功读取它。这是从我的数据库中成功重建图像和基本事实的片段:
data_path = 'Training/train.tfrecords' # address to save the hdf5 file
record_iterator = tf.python_io.tf_record_iterator(path=data_path)
reconstructed_images = []
reconstructed_groundtruths = []
count = 0
for string_record in record_iterator:
example = tf.train.Example()
example.ParseFromString(string_record)
height = int(example.features.feature['height']
.int64_list
.value[0])
width = int(example.features.feature['width']
.int64_list
.value[0])
gt_string = (example.features.feature['train/groundTruth']
.bytes_list
.value[0])
image_string = (example.features.feature['train/image']
.bytes_list
.value[0])
img_1d = np.fromstring(image_string, dtype=np.uint8)
reconstructed_img = img_1d.reshape((height, width))
gt_1d = np.fromstring(gt_string, dtype=np.uint8)
reconstructed_gt = gt_1d.reshape((height, width))
reconstructed_images.append(reconstructed_img)
reconstructed_groundtruths.append(reconstructed_gt)
count += 1
这段代码成功地为我的数据库中的图像和地面实况标签提供了一个 numpy 数组列表。现在,为了尝试实际训练一些东西,我正在使用你可以找到的 MNIST 示例 here。
我已将解码函数替换为以下内容:
def decode(serialized_example):
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'height': tf.FixedLenFeature([1],tf.int64),
'width': tf.FixedLenFeature([1],tf.int64),
'train/image': tf.FixedLenFeature([], tf.string),
'train/groundTruth': tf.FixedLenFeature([], tf.string),
})
height = tf.cast(features['height'], tf.int64)
width = tf.cast(features['width'], tf.int64)
image = tf.decode_raw(features['train/image'], tf.uint8)
image.set_shape((height,width))
gt = tf.decode_raw(features['train/groundTruth'], tf.uint8)
gt.set_shape((height,width))
return image, gt
当我 运行 它时,有多个问题表明代码无法读取数据库。如上所述,我将在解析 height
的行上收到错误,其中指出
int() argument must be a string, a bytes-like object or a number, not 'Tensor'
如果我暂时将 height
和 width
设置为文字,我会在图像解析行上收到一条错误消息
Shapes (?,) and (512, 512) are not compatible
很明显,这意味着图像没有从数据库中正确读取,但我完全不知道为什么或如何修复它。有人可以告诉我我做错了什么吗?
我很幸运地找到了解决方案。显然,
image.set_shape((height,width))
应该是
image = tf.reshape(image,(height,width,1))
和 gt 类似。我不知道为什么我正在关注的 Tensorflow 教程使用 set_shape...我只能猜测它适用于 1d 但不适用于 2d 或更多?我现在可以看到它也不是张量函数,所以它不能使用像高度这样的图形相关变量,但这并不能解释为什么当我用全局替换 (height,width) 时它不起作用常数。如果有人知道,将不胜感激。