在pytorch中提取MiDaS神经网络的中间表示?

Extract intermediate representation of MiDaS neural network in pytorch?

Pytorch文档提供了concise way应用MiDaS单目深度估计网络进行深度提取。但是我应该如何修改他们的代码以在某个中间层提取网络表示?我知道我可以从 github 下载模型并将 forward 函数修改为 return 我想要的,但我对最简单的解决方案感兴趣,将外部代码保持原样。

我知道子class模型 class 和编写我自己的正向函数,比如 here,但我不知道如何访问 class 在代码中。使用 midas = torch.hub.load("intel-isl/MiDaS", model_type) 立即创建模型实例。也许使用前向钩子的例子会更容易。

如您所说,在 nn.Module 上使用正向钩子是最简单的方法。考虑文档:https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook

基本上,您只需定义一个接受三个输入的函数 (module, input, output),然后根据您的需要对这些数据执行任何操作。要找到您要放置该挂钩的模块,您显然需要熟悉模型的结构。您只需 print(midas) 即可获得所有可用模块的精美打印表示。我只是随机选择了一个,并使用 print() 函数作为一个钩子:

midas.pretrained.model.blocks[3].mlp.fc2.register_forward_hook(print)

这意味着每当我们调用 midas(some_input) 时,钩子(在本例中为 print)将使用相应的参数被调用。当然,您可以编写一个将这些文件保存到例如 print 的函数,而不是 print您可以从外部访问的列表,或将它们写入文件等。