创建 TF 数据集时无法解析 TFRecords
Fail to parse TFRecords while creating TF dataset
我正在尝试编写代码来解析 TFRecords 并创建 TF 数据集。我从图像列表创建 TFRecords 文件,并且能够读回它并成功解码我的图像。我的代码基于此 blog 中的示例。但是当我尝试读取我的 TFRecords 文件并创建 TF 数据集时,它失败并出现以下错误:
ValueError: Argument must be a dense tensor: FixedLenFeature(shape=[], dtype=tf.int64, default_value=None) - got shape [3], but wanted [3, 0]
以下是尝试创建数据集的代码摘要:
dataset = tf.data.TFRecordDataset(fnames)
dataset = dataset.map(parse_tfrec)
其中parse_tfrec
是解析单个proto记录的函数:
def parse_tfrec(example_proto):
features={
'height': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[0]),
'width': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[1]),
'depth': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[2]),
'label': tf.FixedLenFeature([], tf.int64, default_value=0),
'image': tf.FixedLenFeature([], tf.string, default_value=''),
}
parsed_features = tf.parse_single_example(example_proto, features)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
depth = tf.cast(features['depth'], tf.int32)
label = tf.cast(features['label'], tf.int32)
image = tf.decode_raw(features['image'], tf.uint8)
image_shape = tf.pack([height, width, depth])
image = tf.reshape(image, image_shape)
return image, label
代码在尝试从 TFRecords(或任何其他存储的整数)解析 height
时失败。而且,我不确定我是否理解有关形状的失败消息。
有什么建议吗?
能否详细说明错误发生在哪一行?它是否出现在 'parse_single_example' 行?还是在下一行?
我注意到的一件事是,在您的 cast 语句中,您使用的是 features
字典,而不是 parsed_features
。
将您的代码更改为类似这样的内容可能会解决您的问题:
height = tf.cast(parsed_features['height'], tf.int32)
如果问题仍然存在,请告诉我。我最近花了一整天自己调试 tfrecords :) 一开始它们可能很难掌握,但最终我能够在批处理生成时间中获得巨大的性能提升。
我正在尝试编写代码来解析 TFRecords 并创建 TF 数据集。我从图像列表创建 TFRecords 文件,并且能够读回它并成功解码我的图像。我的代码基于此 blog 中的示例。但是当我尝试读取我的 TFRecords 文件并创建 TF 数据集时,它失败并出现以下错误:
ValueError: Argument must be a dense tensor: FixedLenFeature(shape=[], dtype=tf.int64, default_value=None) - got shape [3], but wanted [3, 0]
以下是尝试创建数据集的代码摘要:
dataset = tf.data.TFRecordDataset(fnames)
dataset = dataset.map(parse_tfrec)
其中parse_tfrec
是解析单个proto记录的函数:
def parse_tfrec(example_proto):
features={
'height': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[0]),
'width': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[1]),
'depth': tf.FixedLenFeature([], tf.int64, default_value=IMG_SHAPE[2]),
'label': tf.FixedLenFeature([], tf.int64, default_value=0),
'image': tf.FixedLenFeature([], tf.string, default_value=''),
}
parsed_features = tf.parse_single_example(example_proto, features)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
depth = tf.cast(features['depth'], tf.int32)
label = tf.cast(features['label'], tf.int32)
image = tf.decode_raw(features['image'], tf.uint8)
image_shape = tf.pack([height, width, depth])
image = tf.reshape(image, image_shape)
return image, label
代码在尝试从 TFRecords(或任何其他存储的整数)解析 height
时失败。而且,我不确定我是否理解有关形状的失败消息。
有什么建议吗?
能否详细说明错误发生在哪一行?它是否出现在 'parse_single_example' 行?还是在下一行?
我注意到的一件事是,在您的 cast 语句中,您使用的是 features
字典,而不是 parsed_features
。
将您的代码更改为类似这样的内容可能会解决您的问题:
height = tf.cast(parsed_features['height'], tf.int32)
如果问题仍然存在,请告诉我。我最近花了一整天自己调试 tfrecords :) 一开始它们可能很难掌握,但最终我能够在批处理生成时间中获得巨大的性能提升。