在示例中 nn.Module 的实现中未覆盖 forward()

forward() not overridden in implementation of nn.Module in an example

this示例中,我们看到nn.Module的以下实现:

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

但是,在 docs 中我们有 'forward(*input)'“应该被所有子类覆盖”。

为什么在这个例子中不是这样?

Net 模块旨在通过两个单独的接口 encoderdecode 使用,至少看起来是这样...因为它没有 forward 实现,那么是的,它不正确地继承自 nn.Module。但是,该代码仍然“有效”,并且 运行 正确,但如果您使用正向挂钩,可能会有一些副作用。

nn.Module进行推理的标准方法是调用对象,调用__call__ 函数。这个 __call__ 函数由父级 class nn.Module 实现,并依次做两件事:

  • 在推理调用之前或之后处理前向挂钩
  • 调用class的forward函数。

__call__ 函数充当 forward 的包装器。 因此,出于这个原因,forward 函数预计会被用户定义的 nn.Module 覆盖。违反此设计模式的唯一警告是它将有效地忽略应用在 nn.Module.

上的任何挂钩