在pytorch中,什么情况下损失函数需要继承nn.module?
In pytorch, what situations the loss function need to inherit nn.module?
我对 PyTorch 中的损失函数感到困惑。有些人将损失函数定义为一个普通的python函数,而另一些人则通过定义一个继承nn.Module的class来定义损失函数。所以我想知道什么情况下我们需要通过继承nn.Module来定义损失函数?非常感谢。
一般来说,只有当你想在这个模块中有可训练的变量时,才需要从nn.Module
继承,否则继承它是可选的。
同样适用于损失函数,如果它不包含此类变量(我认为这是主要情况),则不需要继承。
我对 PyTorch 中的损失函数感到困惑。有些人将损失函数定义为一个普通的python函数,而另一些人则通过定义一个继承nn.Module的class来定义损失函数。所以我想知道什么情况下我们需要通过继承nn.Module来定义损失函数?非常感谢。
一般来说,只有当你想在这个模块中有可训练的变量时,才需要从nn.Module
继承,否则继承它是可选的。
同样适用于损失函数,如果它不包含此类变量(我认为这是主要情况),则不需要继承。