如何 return PyTorch 中模块前向函数的额外损失?

How to return extra loss from module forward function in PyTorch?

我制作了一个模块,它需要一个额外的损失项,例如

class MyModule:
  def forward(self, x):
    out = f(x)
    extra_loss = loss_f(self.parameters(), x)
    return out, extra_loss

我不知道如何使这个模块可嵌入,例如,嵌入到 Sequential 模型中:任何像 Linear 这样的常规模块放在这个模块之后都会失败,因为 extra_loss 导致 Linear 的输入成为元组,Linear 不支持。

所以我正在寻找的是在 运行 模型向前

之后提取额外的损失
my_module = MyModule()

model = Sequential(
  my_module,
  Linear(my_module_outputs, 1)
)

output = model(x)
my_module_loss = ????
loss = mse(label, output) + my_module_loss

模块可组合性是否支持这种情况?

你可以在这种情况下注册一个钩子。可以在 Tensor 或 nn.Module 上注册一个钩子。钩子是在调用 forward 或 backward 时执行的函数。在这种情况下,我们想要附加一个前向钩子而不从图中分离出来,以便可以发生向后传递。

import torch.nn as nn
act_out = {}
def get_hook(name):
    def hook(m, input, output):
        act_out[name] = output
    return hook

class MyModule(torch.nn.Module):
  def __init__(self, input, out, device=None):
    super().__init__()
    self.model = nn.Linear(input,out)
  def forward(self,x):
    return self.model(x), torch.sum(x) #our extra loss


class MyModule1(torch.nn.Module):
  def __init__(self, input, out, device=None):
    super().__init__()
    self.model = nn.Linear(input,out)
  def forward(self, pair):
    x, loss = pair
    return self.model(x)


model = nn.Sequential(
    MyModule(5,10),
    MyModule1(10,1)
)

for name, module in model.named_children():
    print(name, module)
    if name == '0':
        module.register_forward_hook(get_hook(name))

x = torch.tensor([1,2,3,4,5]).float()
out = model(x)

print(act_out)
loss = myanotherloss(out)+act_out['0'][1] # this is the extra loss
# further processing

注意:我使用的是 name == '0',因为这是我想要附加挂钩的唯一模块。

注意:另一个值得注意的地方是nn.Sequential 不允许多输入。在这种情况下,它被简单地视为一个元组,然后从该元组我们使用 lossinput.

恕我直言,这里的钩子反应过度了。如果 extra_loss 是可加的,我们可以像这样使用全局变量:

class MyModule:
    extra_loss =0
    def forward(self, x):
        out = f(x)
        MyModule.extra_loss += loss_f(self.parameters(), x)
        return out

output = model(x)
loss = mse(label, output) + MyModule.extra_loss
MyModule.extra_loss =0