如何在过滤后获得 Tensorflow 数据集的正确基数

How to get the correct cardinality of a Tensorflow dataset after filtering

我创建了一个包含0到49个元素的TensorFlow数据集,然后过滤它只保留小于25个元素,如下

import tensorflow as tf
dataset = tf.data.Dataset.range(50) 
dataset_less_25 = dataset.filter(lambda x: x < 25)

然而,当我如下检查新数据集的基数时:

dataset_less_25.cardinality().numpy() 

它returns-2,这没有意义。我进一步检查新数据集实际上包含 25 个元素,所以我想知道为什么 cardinality() 函数在这种情况下不起作用?

检查 the docs of this method, there are special integer codes for infinite as well as unknown cardinalities. Way at the bottom,我们看到 -2 代码表示未知基数。也就是说,该方法无法推断数据集大小。实际上,filter 用作具有未知基数的数据集的示例。

为什么会这样,我不确定。深入研究代码,cardinality() 的实现是 here。这导致 gen_dataset_ops.dataset_cardinality。但是我在代码库中找不到 gen_dataset_ops 。它可能是从其他地方自动生成的文件。

我假设此方法仅执行非常基本的分析(例如,对于 range 数据集,很容易说出有多少元素)而无需实际评估任何数据集元素,如果这个简单的方法无法成功(因为不清楚哪些元素会在不实际查看元素的情况下通过过滤器),它 returns“未知”。