如何在 torchvision 的 Inception V3 中将钩子附加到 ReLU
How to attach hooks to ReLUs in Inception V3 from torchvision
我正在使用 Inception v3 from torchvision。我试图在模型中找到 ReLU:
def recursively_find_submodules(model, submodule_type):
module_list = []
q = [model]
while q:
child = q.pop()
if isinstance(child, submodule_type):
module_list.append(child)
q.extend(list(child.children()))
return module_list
inception = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
l = recursively_find_submodules(inception, torch.nn.ReLU) # l is empty!
所以 ReLU 不是火炬模型中任何模块的子模块。仔细检查后,我在 torchvision 的源代码中发现了 ReLU,但没有作为模块。在 inception.py
中,我发现了以下内容:
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)
所以 BasicConv2d
模块使用 ReLU 函数来限制它的输出而不是模块 (torch.nn.ReLU
)。我想如果不修改整个模型以使用 ReLU 模块,就无法连接到 ReLU 函数并修改它们的输入/输出,或者有没有办法做到这一点?
考虑到您观察的是 ReLU 的输入而不是激活后的特征,您可以挂钩到 ReLU 之前的批处理规范层并附加到那里。
我正在使用 Inception v3 from torchvision。我试图在模型中找到 ReLU:
def recursively_find_submodules(model, submodule_type):
module_list = []
q = [model]
while q:
child = q.pop()
if isinstance(child, submodule_type):
module_list.append(child)
q.extend(list(child.children()))
return module_list
inception = torch.hub.load('pytorch/vision:v0.6.0', 'inception_v3', pretrained=True)
l = recursively_find_submodules(inception, torch.nn.ReLU) # l is empty!
所以 ReLU 不是火炬模型中任何模块的子模块。仔细检查后,我在 torchvision 的源代码中发现了 ReLU,但没有作为模块。在 inception.py
中,我发现了以下内容:
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return F.relu(x, inplace=True)
所以 BasicConv2d
模块使用 ReLU 函数来限制它的输出而不是模块 (torch.nn.ReLU
)。我想如果不修改整个模型以使用 ReLU 模块,就无法连接到 ReLU 函数并修改它们的输入/输出,或者有没有办法做到这一点?
考虑到您观察的是 ReLU 的输入而不是激活后的特征,您可以挂钩到 ReLU 之前的批处理规范层并附加到那里。