带有 class 继承的 pytorch 就地运行时错误

pytorch inplace runtime error with class inheritance

我可以知道为什么 this forward() function gives runtime errorinplace 操作上吗?

注意:我进行了一些代码调试,导致以下代码行:

class ConvEdge(Edge):    
    def __init__(self, stride): 
        super().__init__()        
        self.f = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=(3, 3), stride=(stride, stride), padding=1)

如果你们看一下有关 class ConvEdge(Edge) 的代码片段,我实际上正在重新考虑如何 inheritance 正在被 pytorch autograd 库查看和处理。

大家怎么看?

问题已使用 with torch.no_grad() 解决,在不需要时基本上不会传播梯度。