单元测试pytorch前向函数

Unittest the pytorch forward function

我想在 Pytorch 中对我的网络模型的覆盖前向函数进行单元测试。所以我用 setUp 方法加载了我的模型(从 Zoo 预训练),加载了一个种子并创建了一些随机批次。在我的方法 testForward 中,我测试了 forward 与 shape 和 numel 的结果,但我也想检查一个特定的值,该值似乎为 0。我不确定这一点,所以也在 setUp 中检查了我的参数,这似乎不是为 0.

import unittest
import torch
from SemanticSegmentation.models.fcn8 import FCN8


class TestFCN8(unittest.TestCase):

    def setUp(self):
        self.model = FCN8(8, pretrained=True)
        torch.manual_seed(0)
        self.x = torch.rand((4, 3, 45, 45))
        for param in self.model.parameters():
            print(param.data)

    def testForward(self):
        self.assertEqual(self.model.forward(self.x).shape.numel(), 64800)
        self.assertEqual(str(self.model.forward(self.x).shape), 'torch.Size([4, 8, 45, 45])')
        print(self.model.named_parameters)


if __name__ == "__main__":
    unittest.main()

所以我的问题是:正向 return 张量的 sahpe 是我所期望的,但为什么这个张量完全为零?我预计至少有几个值。

导入的模型基于VGG16网络,并在ConvLayer 4、8和16之后进行了升级。如果需要,我也可以提供模型代码。

好的,在修改和调试前向函数后,我得出以下解释:

关于架构的一些信息

如果你从 Andrew Ng 或其他人那里做 类,你会学到不要将权重初始化为相同的值,例如“0”。这就是 FCN 原始论文的作者所做的和他们所说的,因为它不会改变性能或不会屈服于更快的收敛 (FCN-Paper)。

我的解决方案

因此,出于测试目的,我在测试模块中初始化以播种随机值,我可以针对这些值进行测试:

import unittest
import torch
from SemanticSegmentation.models.fcn8 import FCN8


class TestFCN8(unittest.TestCase):

    def setUp(self):
        self.model = FCN8(8, pretrained=True)
        torch.manual_seed(0)
        # instead of zero init for score tensors use random init
        self.model.score_fr[6].weight.data.random_()
        self.model.score_fr[6].bias.data.random_()
        self.model.score_pool3.weight.data.random_()
        self.model.score_pool3.bias.data.random_()
        self.model.score_pool4.weight.data.random_()
        self.model.score_pool4.bias.data.random_()
        self.x = torch.rand((4, 3, 45, 45))

    def testForward(self):
        self.assertEqual(
            self.model.forward(self.x).shape.numel(), 64800)
        self.assertEqual(
            list(self.model.forward(self.x).shape), [4, 8, 45, 45])
        self.assertEqual(
            float(self.model.forward(self.x)[3][4][44][4]), 2277257216.0))

if __name__ == "__main__":
    unittest.main()