加载 Tensorflow 数据集进行训练时,何时使用 .repeat?

When is .repeat used when loading a Tensorflow dataset for training?

我看过在对 Tensorflow 数据集进行加载、改组、映射、批处理、预取等时使用 .repeat() 的教程,而其他教程则完全跳过它。

我知道 repeat 的作用和使用方法,但无法弄清楚何时使用何时不使用它。

有什么帮助吗?

视情况而定。让我们以 MNIST 为例。假设我们使用 from_tensor_slices 构建数据集。训练数据集有 60000 个样本。

假设我们使用批量大小 100 并且不使用 repeat。这意味着数据集将提供 600 个批次。现在,如果我们尝试训练模型,例如使用 keras fit 接口,数据集将在 600 步后简单地 运行 出样本!我们将无法训练更多。使用 repeat,数据集将在 运行 出来后简单地“重新开始”,我们可以根据需要进行训练。

其他教程可能会使用手动训练循环。也许你有一个像

这样的循环
for batch in data_set:
    ...

在此示例中,如果我们不使用 repeat,循环将在 600 个批次后再次停止。但是,我们可以这样做:

for epoch in range(n_epochs):
    for batch in data_set:
        ...

在这个例子中,我们指定了n_epochs中数据集的遍历次数。内循环在 600 个批次后停止,但随后外循环(纪元)简单地递增 1,内循环再次开始。这样,即使不使用 repeat.

,我们也可以拥有 600 多个批次

最后,当然还有其他创建数据集的方法。例如,from_generator 可用于流式传输来自 Python 生成器的数据集,该生成器可以 运行 无限长,因此根本不需要 repeat

在没有看过您所指的教程的情况下,我只能猜测 repeat 使用方面的差异可以用训练循环编码方式的差异来解释,例如上述。