使用 tensorflow 时在异常中包含 tfrecord 名称
Include tfrecord name in exception when using tensorflow
我正在尝试通过 tensorflow 数据集管道构建一些调试代码。基本上,如果某个文件的 tfrecord 解析失败,我希望能够找出是哪个文件。我的梦想是 运行 在我的 parsing_function 中提供一些断言,如果它们失败了则提供文件名。
我的管道看起来像这样:
tf.data.Dataset.from_tensor_slices(file_list)
.apply(tf.contrib.data.parallel_interleave(lambda f: tf.data.TFRecordDataset(f), cycle_length=4))
.map(parse_func, num_parallel_calls=params.num_cores)
.map(_func_for_other_stuff)
理想情况下,我会在 parallel_interleave 步骤中传递文件名,但如果我有匿名函数 return 一个文件名,tfrecordataset 元组,我得到:
TypeError: `map_func` must return a `Dataset` object.
我也曾尝试将文件名包含在文件本身中,例如 问题,但我在这里遇到了问题,因为文件名的长度是可变的。
函数传给tf.contrib.data.parallel_interleave()
must be a tf.data.Dataset
. Therefore you can solve this by attaching the filename tensor to each element of the TFRecordDataset
, using tf.data.Dataset.zip()
的return值如下:
def read_records_func(filename):
records = tf.data.TFRecordDataset(filename)
# Create a dataset from the filename tensor and repeat it indefinitely.
filename_as_dataset = tf.data.Dataset.from_tensors(filename).repeat(None)
return tf.data.Dataset.zip((filename_as_dataset, records))
dataset = (tf.data.Dataset.from_tensor_slices(file_list)
.apply(tf.contrib.data.parallel_interleave(read_records_func, cycle_length=4))
.map(parse_func, num_parallel_calls=params.num_cores)
.map(_func_for_other_stuff))
我正在尝试通过 tensorflow 数据集管道构建一些调试代码。基本上,如果某个文件的 tfrecord 解析失败,我希望能够找出是哪个文件。我的梦想是 运行 在我的 parsing_function 中提供一些断言,如果它们失败了则提供文件名。
我的管道看起来像这样:
tf.data.Dataset.from_tensor_slices(file_list)
.apply(tf.contrib.data.parallel_interleave(lambda f: tf.data.TFRecordDataset(f), cycle_length=4))
.map(parse_func, num_parallel_calls=params.num_cores)
.map(_func_for_other_stuff)
理想情况下,我会在 parallel_interleave 步骤中传递文件名,但如果我有匿名函数 return 一个文件名,tfrecordataset 元组,我得到:
TypeError: `map_func` must return a `Dataset` object.
我也曾尝试将文件名包含在文件本身中,例如
函数传给tf.contrib.data.parallel_interleave()
must be a tf.data.Dataset
. Therefore you can solve this by attaching the filename tensor to each element of the TFRecordDataset
, using tf.data.Dataset.zip()
的return值如下:
def read_records_func(filename):
records = tf.data.TFRecordDataset(filename)
# Create a dataset from the filename tensor and repeat it indefinitely.
filename_as_dataset = tf.data.Dataset.from_tensors(filename).repeat(None)
return tf.data.Dataset.zip((filename_as_dataset, records))
dataset = (tf.data.Dataset.from_tensor_slices(file_list)
.apply(tf.contrib.data.parallel_interleave(read_records_func, cycle_length=4))
.map(parse_func, num_parallel_calls=params.num_cores)
.map(_func_for_other_stuff))