在 Pytorch 内置的自定义 Batchnorm 中更新 running_mean 和 running_var 有问题吗?

Problem with updating running_mean and running_var in a custom Batchnorm built in Pytorch?

我一直在尝试实现自定义批归一化函数,以便它可以扩展到多 GPU 版本,特别是 Pytorch.The 自定义批归一化中的 DataParallel 模块在使用 1 个 GPU 时工作正常,但是,当扩展到2或更多时,运行均值和方差在前向函数中起作用,但当它returns从网络返回时,均值和方差重新初始化为0和1。

torch.nn.DataParallel 在警告部分提到“在每个转发中,模块在每个设备上复制,因此转发中对 运行 模块的任何更新都将丢失。例如,如果模块有一个在每次转发中递增的计数器属性,它将始终保持在初始值,因为更新是在转发后销毁的副本上完成的。”但我不太确定如何保留默认设备的均值和方差。

我已经提供了多 GPU 训练期间获得的结果的代码。此代码利用提供的 Batchnorm here.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from torch.nn.parameter import Parameter

class ptrblck_BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(ptrblck_BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input


class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.bn1 = ptrblck_BatchNorm2d(64)
        print("==> printing bn1 mean when init")
        print(self.bn1.running_mean)
        print("==> printing bn1 when init")
        print(self.bn1.running_mean)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.classifier = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool(x)
        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        print("======================================================")
        print("==> printing bn1 running mean from NET during forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_mean)
        print("==> printing bn1 running var from NET during forward")
        print(net.module.bn1.running_var)
        print("==> printing bn1 running mean from SELF. during forward")
        print(self.bn1.running_var)
        return x

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
net = net()
net = torch.nn.DataParallel(net).cuda()
print('Number of GPU {}'.format(torch.cuda.device_count()))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.cuda(), targets.cuda()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        print("====================================================")
        print("==> printing bn1 running mean FROM net after forward")
        print(net.module.bn1.running_mean)
        print("==> printing bn1 running var FROM net after forward")
        print(net.module.bn1.running_var)

        break
        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()

        # train_loss += loss.item()
        # _, predicted = outputs.max(1)
        # total += targets.size(0)
        # correct += predicted.eq(targets).sum().item()

        # break


for epoch in range(0, 1):
    train(epoch)

结果:

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
==> Building model..
==> printing bn1 mean when init
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
==> printing bn1 when init
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
Number of GPU 2

Epoch: 0
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([ 0.0053,  0.0010, -0.0077, -0.0290,  0.0241,  0.0258, -0.0048,  0.0151,
        -0.0133,  0.0080,  0.0197, -0.0042, -0.0188,  0.0233,  0.0310, -0.0230,
        -0.0133,  0.0222,  0.0119, -0.0042, -0.0220, -0.0169, -0.0342, -0.0025,
         0.0338, -0.0070,  0.0202,  0.0050,  0.0108,  0.0008,  0.0363,  0.0347,
        -0.0106,  0.0082,  0.0128,  0.0074,  0.0111, -0.0030, -0.0089,  0.0070,
        -0.0262, -0.0029,  0.0053, -0.0136, -0.0183,  0.0045, -0.0014, -0.0221,
         0.0132,  0.0064,  0.0388, -0.0220, -0.0008,  0.0400, -0.0187,  0.0397,
        -0.0131, -0.0176,  0.0035,  0.0055, -0.0270,  0.0066, -0.0149,  0.0135],
       device='cuda:0')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9665, 0.9073, 0.9220, 1.0947, 1.0687, 0.9624, 0.9252, 0.9131, 0.9066,
        0.9536, 0.9258, 0.9203, 1.0359, 0.9690, 1.1066, 1.0636, 0.9135, 0.9644,
        0.9373, 0.9846, 0.9696, 0.9454, 1.0459, 0.9245, 0.9778, 0.9709, 0.9352,
        0.9995, 0.9657, 0.9510, 1.0943, 1.0171, 0.9298, 1.0747, 0.9341, 0.9635,
        0.9978, 0.9303, 0.9261, 0.9137, 0.9569, 1.0066, 1.0463, 0.9955, 0.9621,
        0.9172, 0.9836, 0.9817, 0.9086, 0.9576, 1.0905, 0.9861, 0.9661, 1.1773,
        0.9345, 1.0904, 0.9133, 1.0660, 0.9164, 0.9058, 0.9446, 0.9225, 1.0914,
        0.9292], device='cuda:0')
======================================================
==> printing bn1 running mean from NET during forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([-0.0020,  0.0002, -0.0103, -0.0426,  0.0386,  0.0311, -0.0059,  0.0151,
        -0.0140,  0.0145,  0.0218, -0.0029, -0.0281,  0.0284,  0.0449, -0.0329,
        -0.0107,  0.0278,  0.0135, -0.0123, -0.0260, -0.0214, -0.0423, -0.0035,
         0.0410, -0.0097,  0.0276,  0.0102,  0.0197, -0.0001,  0.0483,  0.0451,
        -0.0078,  0.0190,  0.0135, -0.0004,  0.0196, -0.0028, -0.0140,  0.0070,
        -0.0332, -0.0110,  0.0151, -0.0210, -0.0226,  0.0074, -0.0088, -0.0314,
         0.0125, -0.0003,  0.0505, -0.0312,  0.0086,  0.0544, -0.0245,  0.0528,
        -0.0086, -0.0290,  0.0063,  0.0042, -0.0339,  0.0061, -0.0277,  0.0092],
       device='cuda:1')
==> printing bn1 running var from NET during forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')
==> printing bn1 running mean from SELF. during forward
tensor([0.9665, 0.9072, 0.9211, 1.0999, 1.0714, 0.9610, 0.9209, 0.9125, 0.9063,
        0.9553, 0.9260, 0.9189, 1.0386, 0.9706, 1.1139, 1.0610, 0.9121, 0.9660,
        0.9366, 0.9886, 0.9683, 0.9454, 1.0511, 0.9227, 0.9792, 0.9704, 0.9330,
        0.9989, 0.9657, 0.9476, 1.1008, 1.0191, 0.9294, 1.0814, 0.9320, 0.9642,
        1.0006, 0.9287, 0.9254, 0.9128, 0.9559, 1.0100, 1.0521, 0.9972, 0.9621,
        0.9168, 0.9849, 0.9803, 0.9083, 0.9556, 1.0946, 0.9865, 0.9651, 1.1880,
        0.9330, 1.0959, 0.9116, 1.0706, 0.9149, 0.9057, 0.9450, 0.9215, 1.0972,
        0.9261], device='cuda:1')
====================================================
==> printing bn1 running mean FROM net after forward
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
==> printing bn1 running var FROM net after forward
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')

如何确保使用 运行 估计的默认设备?目前,我没有致力于同步 Batchnorm。

正在替换

self.running_mean = (...)

self.running_mean.copy_(...)

完成任务。

Reference