如何获得 tf.data.Dataset 的长度 (data_size / batch_size)?

How do I get the length(data_size / batch_size) of the tf.data.Dataset?

我想得到我的 tf.data.Dataset 的长度。 (data_size / batch_size)

在Pytorch中,我可以通过简单的代码得到这个:

length = len(data_loader)

但是,它在 tensorflow 2.0 中不起作用。

我如何得到这个?

在 TensorFlow 2.0 中,您创建一个 tf.data.Dataset 对象,即 Python 可迭代对象。

在遍历所有元素之前,您无法预先知道数据集中有多少元素。

因此,假设您以这种方式创建了一个数据集:

batch_size = 12
dataset = tf.data.Dataset.from_tensor_slices(something).batch(batch_size)

你可以这样得到批次总数:

number_of_batches = len([_ for _ in iter(dataset)])