使用 torchviz 显示具有多个输出的 PyTorch 模型 make_dots
Display PyTorch model with multiple outputs using torchviz make_dots
我有一个具有多个输出的模型,确切地说是 4 个:
def forward(self, x):
outputs = []
for conv, act in zip(self.Convolutions, self.Activations):
y = conv(x)
outputs.append(act(y))
return outputs
我想使用 torchviz
中的 make_dot
显示它:
from torchviz import make_dot
generator = ...
batch = next(iter(generator))
input, output = batch["input"].to(device, dtype=torch.float), batch["output"].to(device, dtype=torch.float)
dot = make_dot(model(input), params=dict(model.named_parameters()))
但是我得到以下错误:
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/torchviz/dot.py", line 37, in make_dot
output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
AttributeError: 'list' object has no attribute 'grad_fn'
显然列表没有grad_fn
函数,但是根据this discussion,我可以return输出列表。
我做错了什么?
模型可以 return 一个列表,但是 make_dot
想要一个 Tensor
。如果输出组件具有相似的形状,我建议在其上使用 torch.cat
。
我有一个具有多个输出的模型,确切地说是 4 个:
def forward(self, x):
outputs = []
for conv, act in zip(self.Convolutions, self.Activations):
y = conv(x)
outputs.append(act(y))
return outputs
我想使用 torchviz
中的 make_dot
显示它:
from torchviz import make_dot
generator = ...
batch = next(iter(generator))
input, output = batch["input"].to(device, dtype=torch.float), batch["output"].to(device, dtype=torch.float)
dot = make_dot(model(input), params=dict(model.named_parameters()))
但是我得到以下错误:
File "/opt/local/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/torchviz/dot.py", line 37, in make_dot
output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)
AttributeError: 'list' object has no attribute 'grad_fn'
显然列表没有grad_fn
函数,但是根据this discussion,我可以return输出列表。
我做错了什么?
模型可以 return 一个列表,但是 make_dot
想要一个 Tensor
。如果输出组件具有相似的形状,我建议在其上使用 torch.cat
。