在 TensorFlow 2.0 中,如何查看数据集中的元素数量?

In TensorFlow 2.0, how can I see the number of elements in a dataset?

当我加载数据集时,我想知道是否有任何快速方法可以找到该数据集中的样本数或批次数。我知道如果我使用 with_info=True 加载数据集,我可以看到例如 total_num_examples=6000,,但如果我拆分数据集,则此信息不可用。

目前统计样本数如下,不知有没有更好的解决方案:

train_subsplit_1, train_subsplit_2, train_subsplit_3 = tfds.Split.TRAIN.subsplit(3)

cifar10_trainsub3 = tfds.load("cifar10", split=train_subsplit_3)

cifar10_trainsub3 = cifar10_trainsub3.batch(1000)

n = 0
for i, batch in enumerate(cifar10_trainsub3.take(-1)):
    print(i, n, batch['image'].shape)
    n += len(batch['image'])

print(i, n)

如果可以知道长度,那么您可以使用:

tf.data.experimental.cardinality(dataset)

但问题是 TF 数据集本质上是延迟加载的。所以我们可能事先不知道数据集的大小。事实上,完全有可能让一个数据集代表一个无限的数据集!

如果它是一个足够小的数据集,您也可以迭代它来获得长度。我之前使用过以下丑陋的小结构,但它取决于数据集是否足够小,我们可以很乐意将其加载到内存中,而且它确实不是上面 for 循环的改进!

dataset_length = [i for i,_ in enumerate(dataset)][-1] + 1