torch.nn.Sequential 的设计块在提供输入时存在问题
torch.nn.Sequential of designed blocks problem in giving inputs
我设计了一个class,它是一个网络块,它的前向有三个输入:x、logdet、reverse,并有两个输出。 例如,当我调用此 class 并使用它时一切正常,例如:
x = torch.Tensor(np.random.rand(2, 48, 8, 8))
net = Block(inp = 48, oup = 48, mid_channels=48, ksize=3, stride=1, group = 3)
a, _ = net(x, reverse = False)
但是当我想按Sequential来使用的时候(因为我需要一个接一个的多块),问题是这样的:
x = torch.Tensor(np.random.rand(2, 48, 8, 8))
conv1_network = nn.Sequential(
Block(inp = 48, oup = 48, mid_channels=48, ksize=3, stride=1, group = 3)
)
conv1_network(x, reverse = False)
我的错误是:
TypeError: forward() got an unexpected keyword argument 'reverse'
这是不正常的,因为正如我们在第一部分中看到的那样,我在 Block 中的正向输入中有反向。
我期待找到一种方法将一些块相互连接,例如这是一个块
class Block(nn.Module):
def __init__(self, num_channels):
super(InvConv, self).__init__()
self.num_channels = num_channels
# Initialize with a random orthogonal matrix
w_init = np.random.randn(num_channels, num_channels)
w_init = np.linalg.qr(w_init)[0].astype(np.float32)
self.weight = nn.Parameter(torch.from_numpy(w_init))
def forward(self, x, logdet, reverse=False):
ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)
if reverse:
weight = torch.inverse(self.weight.double()).float()
logdet = logdet - ldj
else:
weight = self.weight
logdet = logdet + ldj
weight = weight.view(self.num_channels, self.num_channels, 1, 1)
z = F.conv2d(x, weight)
return z, logdet
而我的目的是在一个for中将多个Block以Sequential的方式相互连接(因为我不能在我的工作中使用相同的Block,所以我需要不同的卷积来制作深度网络)
features = []
for i in range(10):
self.features.append(Block(num_channels = 48))
然后我想像这样使用它们
self.features(x, logdet = 0, reverse = False)
您表示您的 Block
nn.Module
有一个 reverse
选项。但是 nn.Sequential 没有,所以 conv1_network(x, reverse=False)
无效,因为 conv1_network
不是 Block
.
默认情况下,您无法将 kwargs
传递给 nn.Sequential
内的图层。但是,您可以继承 nn.Sequential
并自己完成。类似于:
class BlockSequence(nn.Sequential):
def forward(self, input, **kwargs):
for module in self:
options = kwargs if isinstance(module, Block) else {}
input = module(input, **options)
return input
这样,您可以创建一个包含 Block
的序列(以及可选的非 Block
模块):
>>> blocks = []
>>> for i in range(10):
... self.blocks.append(Block(num_channels=48))
>>> blocks = BlockSequence(*blocks)
然后您将能够使用 reverse
关键字参数调用 blocks
,调用时会转发给每个潜在的 Block
子模块:
>>> blocks(x, logdet=0, reverse=False)