如何使用pytorch遍历神经网络中的所有参数

How to iterate through all parameters in a neural network using pytorch

我有以下简单的全连接神经网络:

class Neural_net(nn.Module):
    def __init__(self):
        super(Neural_net, self).__init__()
        self.fc1    = nn.Linear(2, 2)        
        self.fc2    = nn.Linear(2, 1)
        self.fc_out = nn.Linear(1, 1)      
        
    def forward(self, x,train = True):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = self.fc_out(x)
        return x

net = Neural_net()

如何遍历网络的所有参数并检查它们是否大于某个值?我正在使用 pytorch,如果我这样做:

for n,p in net.named_parameters():     
       if p > value:
       ...

我收到一个错误,因为 p 不是单个数字,而是每一层的权重或偏差的张量。

我的目标是检查每个参数是否满足标准并标记它们,例如如果是 1 或者如果不是 0 ,将其存储在与 net.parameters() 具有相同结构的字典中。然而,我无法弄清楚如何遍历它们。

我想过创建一个参数向量:

param_vec  =  torch.cat([p.view(-1) for p in net.parameters()])

然后访问参数值并检查它们会很容易,但是我想不出一种方法可以返回到字典形式来标记它们。

感谢您的帮助!

首先,我将标准定义为对张量的运算。在您的情况下,这可能如下所示:

cond = lambda tensor: tensor.gt(value)

然后你只需要将它应用于net.parameters()中的每个张量。为了保持相同的结构,你可以用 dict 理解来做到这一点:

cond_parameters = {n: cond(p) for n,p in net.named_parameters()}

让我们在实践中看看吧!

net = Neural_net()
print(dict(net.parameters())
#> {'fc1.weight': Parameter containing:
#>  tensor([[-0.4767,  0.0771],
#>          [ 0.2874,  0.5474]], requires_grad=True),
#>  'fc1.bias': Parameter containing:
#>  tensor([ 0.0405, -0.1997], requires_grad=True),
#>  'fc2.weight': Parameter containing:
#>  tensor([[0.5400, 0.3241]], requires_grad=True),
#>  'fc2.bias': Parameter containing:
#>  tensor([-0.5306], requires_grad=True),
#>  'fc_out.weight': Parameter containing:
#>  tensor([[-0.9706]], requires_grad=True),
#>  'fc_out.bias': Parameter containing:
#> tensor([-0.4174], requires_grad=True)}

让我们将 value 设置为零并获取参数字典:

value = 0
cond = lambda tensor: tensor.gt(value)
cond_parameters = {n: cond(p) for n,p in net.named_parameters()}
#>{'fc1.weight': tensor([[False,  True],
#>         [ True,  True]]),
#> 'fc1.bias': tensor([ True, False]),
#> 'fc2.weight': tensor([[True, True]]),
#> 'fc2.bias': tensor([False]),
#> 'fc_out.weight': tensor([[False]]),
#> 'fc_out.bias': tensor([False])}