如何在 Tensorflow 中按张量形状过滤数据集

How to filter dataset by tensor shape in Tensorflow

我已经从 tfds.load 加载了一个数据集,并想丢弃某些干扰 training/are 对我没有用(例如,太小)的图像。

似乎在任何地方都没有关于这个特定问题的信息,所以我选择了数据集上最合适的 .filter(predicate) 。不幸的是,谓词的输入具有不确定的形状 (None, None, 3) 并且正如预期的那样会引发错误,即 'int' 无法与 'NoneType'.[=11= 进行比较]

是否可以在 tensorflow 中解决这个问题,或者我不应该浪费时间吗?

伪代码

ds_train = tfds.load('name')
ds_train = ds_train.map(lambda ds: ds['image'])
ds_train = ds_train.filter(lambda image: image.shape[0] >= 256)

使用 tf.data.Dataset 编写代码时,您应该使用 tf.shape(tensor) 而不是 tensor.shape,因为 tf.data.Dataset 在图形模式下工作。

引用 tf.shape 的文档:

tf.shape and Tensor.shape should be identical in eager mode. Within tf.function or within a compat.v1 context, not all dimensions may be known until execution time. Hence when defining custom layers and models for graph mode, prefer the dynamic tf.shape(x) over the static x.shape.

ds_train = ds_train.filter(lambda image: tf.shape(image)[0] >= 256)