如何根据标签跳过 tf.data.TFRecordDataset 中的某些记录?
How to skip some records from tf.data.TFRecordDataset based on their labels?
目前我正在为 ResNet 开发一个训练应用程序。该项目本身不是由我发起的,而是由其他开发人员发起的。我有一个巨大的存储空间,里面装满了源照片、数千个文件,总共 8 TB。所有照片都被分割成数百classes。有时我只需要包含 classes 的一部分,有时我需要排除它们并使用其他 classes,等等。每次我首先需要构建一个数据集,作为数据集创建的结果,我得到了一堆 tfrecord
文件,然后我就可以开始模型训练了。我想优化这个过程:构建一个包含所有照片和所有 classes 的大型数据集,然后在训练时通过它的标签过滤 (exclude/include) classes。
我找到了一个 filter()
方法并且我知道如何将它用于一个简单的数据集:
def filter(x):
if x < -2:
print('x < -2')
return True
elif x > 2:
print('x > 2')
return True
else:
print('False')
return False
d = tf.data.Dataset.from_tensor_slices([-4, -3, -2, -1, 0, 1, 2, 3, 4])
d = d.filter(filter)
这里对我来说第一个有趣的问题是,TF 似乎以某种方式优化了一个函数。输出:
x < -2
x > 2
False
所以并不是每个print
语句都会被执行。但过滤器按预期工作,我得到 [-4, -3, 3, 4]
`
无论如何,我无法理解如何通过 class 标签过滤数据集。数据集由记录组成,每条记录由两张图片和一个class标签组成。这是我的测试代码:
def test_filter_function(x):
feature_description = {
'image_raw': tf.io.FixedLenFeature([], tf.string, default_value=''),
'image2_raw': tf.io.FixedLenFeature([], tf.string, default_value=''),
'label': tf.io.FixedLenFeature([], tf.string, default_value='')
}
print(x)
parsed = tf.io.parse_single_example(x, feature_description)
print(parsed['image_raw'])
print(parsed['label'])
return True
dataset = tf.data.TFRecordDataset(list_of_tfrecord_files)
dataset = dataset.filter(test_filter_function)
输出为:
Tensor("args_0:0", shape=(), dtype=string)
Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=string)
Tensor("ParseSingleExample/ParseSingleExample:2", shape=(), dtype=string)
什么是“args_0:0”?
为什么 parsed['image_raw']
和 print(parsed['label'])
具有相同的 Tensor 类型,而 dtype
是 string
?
如何获得 python 字符串类型形式的 class 标签,以及如何检查它是否在一组启用的 classes 中?
是否可以以及如何排除标签不在集合中的记录?
The first interesting question for me here is that it seems TF somehow optimizes a function.
所以在 eager execution 中,TensorFlow 在生成优化的计算图时只会 运行 第一个函数,之后它只使用 TensorFlow 函数。参见 Here
What is "args_0:0"? Why both, parsed['image_raw'] and print(parsed['label']) have the same Tensor type and it's dtype is a string? How can I get a class label in a form of python string type and how to check is it in a set of enabled classes or not? Is it possible and how to exclude records which labels not in the set?
"args_0:0
" 只是张量的生成名称。图像在记录中被序列化为字符串,你可以使用 image.numpy()
将它们转换回来,同样对于标签你可以使用 byteslist
解码回来。参见 Here。
所以,如果你想过滤,你需要确保你用 tf.py_funcion
包装你的函数,或者更好,只使用 tf.functions
,这样你的代码可以是 运行 渴望执行。
目前我正在为 ResNet 开发一个训练应用程序。该项目本身不是由我发起的,而是由其他开发人员发起的。我有一个巨大的存储空间,里面装满了源照片、数千个文件,总共 8 TB。所有照片都被分割成数百classes。有时我只需要包含 classes 的一部分,有时我需要排除它们并使用其他 classes,等等。每次我首先需要构建一个数据集,作为数据集创建的结果,我得到了一堆 tfrecord
文件,然后我就可以开始模型训练了。我想优化这个过程:构建一个包含所有照片和所有 classes 的大型数据集,然后在训练时通过它的标签过滤 (exclude/include) classes。
我找到了一个 filter()
方法并且我知道如何将它用于一个简单的数据集:
def filter(x):
if x < -2:
print('x < -2')
return True
elif x > 2:
print('x > 2')
return True
else:
print('False')
return False
d = tf.data.Dataset.from_tensor_slices([-4, -3, -2, -1, 0, 1, 2, 3, 4])
d = d.filter(filter)
这里对我来说第一个有趣的问题是,TF 似乎以某种方式优化了一个函数。输出:
x < -2
x > 2
False
所以并不是每个print
语句都会被执行。但过滤器按预期工作,我得到 [-4, -3, 3, 4]
`
无论如何,我无法理解如何通过 class 标签过滤数据集。数据集由记录组成,每条记录由两张图片和一个class标签组成。这是我的测试代码:
def test_filter_function(x):
feature_description = {
'image_raw': tf.io.FixedLenFeature([], tf.string, default_value=''),
'image2_raw': tf.io.FixedLenFeature([], tf.string, default_value=''),
'label': tf.io.FixedLenFeature([], tf.string, default_value='')
}
print(x)
parsed = tf.io.parse_single_example(x, feature_description)
print(parsed['image_raw'])
print(parsed['label'])
return True
dataset = tf.data.TFRecordDataset(list_of_tfrecord_files)
dataset = dataset.filter(test_filter_function)
输出为:
Tensor("args_0:0", shape=(), dtype=string)
Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=string)
Tensor("ParseSingleExample/ParseSingleExample:2", shape=(), dtype=string)
什么是“args_0:0”?
为什么 parsed['image_raw']
和 print(parsed['label'])
具有相同的 Tensor 类型,而 dtype
是 string
?
如何获得 python 字符串类型形式的 class 标签,以及如何检查它是否在一组启用的 classes 中?
是否可以以及如何排除标签不在集合中的记录?
The first interesting question for me here is that it seems TF somehow optimizes a function.
所以在 eager execution 中,TensorFlow 在生成优化的计算图时只会 运行 第一个函数,之后它只使用 TensorFlow 函数。参见 Here
What is "args_0:0"? Why both, parsed['image_raw'] and print(parsed['label']) have the same Tensor type and it's dtype is a string? How can I get a class label in a form of python string type and how to check is it in a set of enabled classes or not? Is it possible and how to exclude records which labels not in the set?
"args_0:0
" 只是张量的生成名称。图像在记录中被序列化为字符串,你可以使用 image.numpy()
将它们转换回来,同样对于标签你可以使用 byteslist
解码回来。参见 Here。
所以,如果你想过滤,你需要确保你用 tf.py_funcion
包装你的函数,或者更好,只使用 tf.functions
,这样你的代码可以是 运行 渴望执行。