pytorch 1.10.1 中的自定义损失函数

Custom loss function in pytorch 1.10.1

我正在努力为 pytorch 1.10.1 定义自定义损失函数。我的模型输出一个从 -1 到 +1 的浮点数。目标值是任意范围的浮点数。如果模型输出和目标之间的符号不同,损失应该是产品的总和。

我在互联网上搜索了好几个小时,但似乎 pytorch 在最近的版本中发生了一些变化,所以我真的不知道哪个示例最适合我的用例和 pytorch 1.10。 1.

到目前为止,这是我的方法:

class Loss(torch.nn.Module):
    @staticmethod
    def forward(self, output, target) -> Tensor:
        loss = 0.0
        
        for i in range(len(target)):
            o = output[i,0]
            t = target[i]
            l = o * t
            if l<0:   #if different sign
                loss -= l

        return loss

问题:

  1. 我应该继承 torch.nn.Module 还是 torch.autograd.Function

  2. 我需要定义@staticmethod吗?

  3. 在某些示例中,我看到 ctx 而不是 self 被使用和调用 ctx.save_for_backward 等。我需要这个吗?它的目的是什么?

  4. 当子类化 torch.nn.Module 时,我的代码抱怨:'Tensor' 对象没有属性 'children'。我错过了什么?

  5. 当继承 torch.autograd.Function 时,我的代码抱怨没有定义后向函数。我的反向函数应该是什么样的?

自定义损失函数可以像 python 函数一样简单。你可以稍微简化一下:

def custom_loss(output, target):
    prod = output[:,0]*target
    return -prod[prod<0].sum()