在示例中 nn.Module 的实现中未覆盖 forward()
forward() not overridden in implementation of nn.Module in an example
在this示例中,我们看到nn.Module
的以下实现:
class Net(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def encode(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
return self.conv2(x, edge_index)
def decode(self, z, edge_label_index):
return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
def decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()
但是,在 docs 中我们有 'forward(*input)
'“应该被所有子类覆盖”。
为什么在这个例子中不是这样?
此 Net
模块旨在通过两个单独的接口 encoder
和 decode
使用,至少看起来是这样...因为它没有 forward
实现,那么是的,它不正确地继承自 nn.Module
。但是,该代码仍然“有效”,并且 运行 正确,但如果您使用正向挂钩,可能会有一些副作用。
对nn.Module
进行推理的标准方法是调用对象,即调用__call__
函数。这个 __call__
函数由父级 class nn.Module
实现,并依次做两件事:
- 在推理调用之前或之后处理前向挂钩
- 调用class的
forward
函数。
__call__
函数充当 forward
的包装器。
因此,出于这个原因,forward
函数预计会被用户定义的 nn.Module
覆盖。违反此设计模式的唯一警告是它将有效地忽略应用在 nn.Module
.
上的任何挂钩
在this示例中,我们看到nn.Module
的以下实现:
class Net(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def encode(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
return self.conv2(x, edge_index)
def decode(self, z, edge_label_index):
return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
def decode_all(self, z):
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()
但是,在 docs 中我们有 'forward(*input)
'“应该被所有子类覆盖”。
为什么在这个例子中不是这样?
此 Net
模块旨在通过两个单独的接口 encoder
和 decode
使用,至少看起来是这样...因为它没有 forward
实现,那么是的,它不正确地继承自 nn.Module
。但是,该代码仍然“有效”,并且 运行 正确,但如果您使用正向挂钩,可能会有一些副作用。
对nn.Module
进行推理的标准方法是调用对象,即调用__call__
函数。这个 __call__
函数由父级 class nn.Module
实现,并依次做两件事:
- 在推理调用之前或之后处理前向挂钩
- 调用class的
forward
函数。
__call__
函数充当 forward
的包装器。
因此,出于这个原因,forward
函数预计会被用户定义的 nn.Module
覆盖。违反此设计模式的唯一警告是它将有效地忽略应用在 nn.Module
.