torchvision.transforms.Normalize() 添加到 torchvision 时会减慢学习速度,transforms.Compose()

torchvision.transforms.Normalize() slows down learning when adding to torchvision,transforms.Compose()

当我使用

train_transforms = torchvision.transforms.Compose([
  torchvision.transforms.ToTensor(), 
  torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

对于加载 MNIST 数据集,即使使用 mean = 0std = 1,它也会减慢学习速度。

转换是在 CPU 上执行的,mean/std 是否全为零并不重要(顺便说一句,不要将 std 设置为 0)。要加快转换速度,您有两个选择:

  1. 如果您的流程中没有任何数据扩充,只需转换数据并将其保存为标准化张量(腌制或其他)。
  2. 您还可以使用带有一些参数的 torch.utils.data.DataLoader:例如 num_workers 指定要使用多少 CPU 个进程来转换数据。如果您使用的是 CUDA,还有 pin_memory 可以加快整个过程。