如何从pytorch dataloader获取批量迭代的总数?

How to get the total number of batch iteration from pytorch dataloader?

我有一个问题,如何从 pytorch dataloader 获取批量迭代的总数?

以下是训练的常用代码

for i, batch in enumerate(dataloader):

那么,有什么方法可以得到for循环的总迭代次数吗?

在我的NLP问题中,总迭代次数与int(n_train_samples/batch_size)不同...

例如,如果我只截断训练数据 10,000 个样本并将批量大小设置为 1024,那么在我的 NLP 问题中会发生 363 次迭代。

我想知道如何获取“for-loop”中的总迭代次数。

谢谢。

len(dataloader)returns批次总数。这取决于数据集的 __len__ 函数,因此请确保它设置正确。

创建数据加载器时有一个附加参数。它被称为drop_last

如果 drop_last=True 则长度为 number_of_training_examples // batch_size。 如果 drop_last=False 可能是 number_of_training_examples // batch_size +1 .

BS=128
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader( ds_train, batch_size=BS, drop_last=True, shuffle=True)

对于预定义的数据集,您可能会得到如下示例的数量:

# number of examples
len(dl_train.dataset) 

dataloader 中正确的批次数始终是:

# number of batches
len(dl_train)