Can my PyTorch forward function do additional operations?
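Typically a forward function strings together a bunch of layers and returns the output of the last one. Can I do some additional processing after that last layer before returning? For example, some scalar multiplication and reshaping via .view?
I know that autograd somehow figures out the gradients, so I don't know whether my additional processing will somehow mess that up. Thanks.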
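PyTorch tracks gradients via the computational graph of the tensors, not through the functions. As long as your tensors have requires_grad=True, so that the tensors computed from them have a grad_fn that is not None, you can do (almost) anything you like and still be able to backpropagate.
As long as you are using PyTorch's operations (e.g., those listed here and here), you should be okay.
For more information, see this.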
For example (taken from torchvision's VGG implementation):
class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        # ...

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # <-- what you were asking about
        x = self.classifier(x)
        return x
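The same pattern applies to the operations named in the question. Below is a minimal sketch, not taken from torchvision: the module name ScaledHead, its layer sizes, and the scale value are all hypothetical and chosen only for illustration. It multiplies by a scalar and reshapes with .view inside forward, then checks that gradients still reach the layer's parameters.

```python
import torch
import torch.nn as nn


class ScaledHead(nn.Module):
    """Hypothetical module: one linear layer followed by extra processing."""

    def __init__(self, in_features=8, out_features=6, scale=0.5):
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        self.scale = scale

    def forward(self, x):
        x = self.fc(x)
        x = x * self.scale      # scalar multiplication, recorded in the graph
        x = x.view(-1, 2, 3)    # reshape via .view, also recorded in the graph
        return x


torch.manual_seed(0)
model = ScaledHead()
out = model(torch.randn(4, 8))

# The output is still attached to the graph, so backward() works as usual.
out.sum().backward()
weight_grad = model.fc.weight.grad
```

After backward(), weight_grad holds a gradient with the same shape as the weight, which suggests the extra processing did not cut the graph.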
For a more complex example, see torchvision's implementation of ResNet:
class Bottleneck(nn.Module):
    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        # ...

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:  # <-- conditional execution!
            identity = self.downsample(x)

        out += identity  # <-- inplace operations
        out = self.relu(out)

        return out
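To make the conditional branch and the in-place add concrete, here is a small sketch under stated assumptions: TinyResidual is a hypothetical, much reduced stand-in for the block above, not torchvision code. The second part illustrates one reason for the "(almost)" in the answer: an in-place operation that overwrites a value autograd saved for the backward pass is expected to raise a RuntimeError.

```python
import torch
import torch.nn as nn


class TinyResidual(nn.Module):
    """Hypothetical reduced residual block with an optional downsample path."""

    def __init__(self, channels=4, use_downsample=True):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.downsample = (
            nn.Conv2d(channels, channels, kernel_size=1) if use_downsample else None
        )

    def forward(self, x):
        identity = x
        out = self.bn(self.conv(x))
        if self.downsample is not None:   # plain Python control flow
            identity = self.downsample(x)
        out += identity                   # in-place add on a non-leaf tensor
        return self.relu(out)


torch.manual_seed(0)
block = TinyResidual()
y = block(torch.randn(2, 4, 5, 5))
y.sum().backward()
grads_ok = all(p.grad is not None for p in block.parameters())

# Counter-example: exp() saves its output for the backward pass,
# so modifying that output in place invalidates the saved value.
a = torch.ones(3, requires_grad=True)
b = a.exp()
b += 1
try:
    b.sum().backward()
    inplace_failed = False
except RuntimeError:
    inplace_failed = True
```

In the residual block the in-place add is safe because the overwritten tensor is not one that the preceding layer needs for its own gradient; in the counter-example it is, which is why writing b = b + 1 instead would avoid the error.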