如何根据标签跳过 tf.data.TFRecordDataset 中的某些记录?

How to skip some records from tf.data.TFRecordDataset based on their labels?

目前我正在为 ResNet 开发一个训练应用程序。该项目本身不是由我发起的,而是由其他开发人员发起的。我有一个巨大的存储空间,里面装满了源照片、数千个文件,总共 8 TB。所有照片都被分割成数百classes。有时我只需要包含 classes 的一部分,有时我需要排除它们并使用其他 classes,等等。每次我首先需要构建一个数据集,作为数据集创建的结果,我得到了一堆 tfrecord 文件,然后我就可以开始模型训练了。我想优化这个过程:构建一个包含所有照片和所有 classes 的大型数据集,然后在训练时通过它的标签过滤 (exclude/include) classes。

我找到了一个 filter() 方法并且我知道如何将它用于一个简单的数据集:

def filter(x):
    if x < -2:
        print('x < -2')
        return True
    elif x > 2:
        print('x > 2')
        return True
    else:
        print('False')
        return False

d = tf.data.Dataset.from_tensor_slices([-4, -3, -2, -1, 0, 1, 2, 3, 4])
d = d.filter(filter)

这里对我来说第一个有趣的问题是,TF 似乎以某种方式优化了一个函数。输出:

x < -2
x > 2
False

所以并不是每个print语句都会被执行。但过滤器按预期工作,我得到 [-4, -3, 3, 4] `

无论如何,我无法理解如何通过 class 标签过滤数据集。数据集由记录组成,每条记录由两张图片和一个class标签组成。这是我的测试代码:

def test_filter_function(x):
    feature_description = {
        'image_raw': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'image2_raw': tf.io.FixedLenFeature([], tf.string, default_value=''),
        'label': tf.io.FixedLenFeature([], tf.string, default_value='')
    }
    print(x)
    parsed = tf.io.parse_single_example(x, feature_description)
    print(parsed['image_raw'])
    print(parsed['label'])
    return True

dataset = tf.data.TFRecordDataset(list_of_tfrecord_files)
dataset = dataset.filter(test_filter_function)

输出为:

Tensor("args_0:0", shape=(), dtype=string)
Tensor("ParseSingleExample/ParseSingleExample:1", shape=(), dtype=string)
Tensor("ParseSingleExample/ParseSingleExample:2", shape=(), dtype=string)

什么是“args_0:0”? 为什么 parsed['image_raw']print(parsed['label']) 具有相同的 Tensor 类型,而 dtypestring? 如何获得 python 字符串类型形式的 class 标签,以及如何检查它是否在一组启用的 classes 中? 是否可以以及如何排除标签不在集合中的记录?

The first interesting question for me here is that it seems TF somehow optimizes a function.

所以在 eager execution 中,TensorFlow 在生成优化的计算图时只会 运行 第一个函数,之后它只使用 TensorFlow 函数。参见 Here

What is "args_0:0"? Why both, parsed['image_raw'] and print(parsed['label']) have the same Tensor type and it's dtype is a string? How can I get a class label in a form of python string type and how to check is it in a set of enabled classes or not? Is it possible and how to exclude records which labels not in the set?

"args_0:0" 只是张量的生成名称。图像在记录中被序列化为字符串,你可以使用 image.numpy() 将它们转换回来,同样对于标签你可以使用 byteslist 解码回来。参见 Here

所以,如果你想过滤,你需要确保你用 tf.py_funcion 包装你的函数,或者更好,只使用 tf.functions,这样你的代码可以是 运行 渴望执行。