如何获取 CNN 模型的输入和输出特征图?

How to get the input and output feature maps of a CNN model?

我试图找到图像在每一层通过卷积神经网络时的尺寸。因此,例如,如果应用了最大池化或卷积,我想知道该层所有层的图像形状。我知道我可以使用 nOut=image+2p-f / s + 1 公式,但考虑到 PyTorch 模型的大小,它会过于乏味和复杂。有没有一种简单的方法可以做到这一点,也许是可视化 tool/script 之类的?

访问https://deeplearning.neuromatch.io/tutorials/W2D1_ConvnetsAndRecurrentNeuralNetworks/student/W2D1_Tutorial1.html 它是一个免费教程,包含神经匹配学院提供的 CNN 的广泛视觉表示。

您可以使用 torchinfo 库:https://github.com/TylerYep/torchinfo

让我们以他们为例:

from torchinfo import summary

model = ConvNet()
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))

这里(1, 28, 28)是输入的大小,分别是图像的(Channel, Width, Height)

图书馆将打印:

================================================================================================================
Layer (type:depth-idx)          Input Shape          Output Shape         Param #            Mult-Adds
================================================================================================================
SingleInputNet                  --                   --                   --                  --
├─Conv2d: 1-1                   [7, 1, 28, 28]       [7, 10, 24, 24]      260                1,048,320
├─Conv2d: 1-2                   [7, 10, 12, 12]      [7, 20, 8, 8]        5,020              2,248,960
├─Dropout2d: 1-3                [7, 20, 8, 8]        [7, 20, 8, 8]        --                 --
├─Linear: 1-4                   [7, 320]             [7, 50]              16,050             112,350
├─Linear: 1-5                   [7, 50]              [7, 10]              510                3,570
================================================================================================================
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 3.41
================================================================================================================
Input size (MB): 0.02
Forward/backward pass size (MB): 0.40
Params size (MB): 0.09
Estimated Total Size (MB): 0.51
================================================================================================================

我认为 7 在这个输出中是错误的。应该是 16.