手动初始化pytorch中的模型参数
Initializing model parameters in pytorch manually
我正在创建一个单独的 class 来初始化模型并在列表中添加层,但是这些层没有被添加到参数中,请告诉我如何将它们添加到模型的 parameters() 中。
class Mnist_Net(nn.Module):
def __init__(self,input_dim,output_dim,hidden_layers=2,neurons=128):
super().__init__()
layers = []
for i in range(hidden_layers):
if len(layers) == 0:
layers.append(nn.Linear(input_dim,neurons))
if i == hidden_layers-1:
layers.append(nn.Linear(layers[-2].weight.shape[0],output_dim))
layers.append(nn.Linear(layers[i-1].weight.shape[0],neurons))
self.layers= layers
当我打印 model.parameters()
model = Mnist_Net(28*28,10,neurons=56)
for t in model.parameters():
print(t)
它什么也没显示,但是当我在 class 中添加图层时,比如
self.layer1 = nn.Linear(input_dim,neurons)
它显示 parameters.Plz 中的一层告诉我如何在 model.parameters()
中添加 self.layers 中的所有层
要在父模块中注册,您的子模块本身应该是 nn.Module
。在你的情况下,你应该用 nn.ModuleList
:
包装 layers
self.layers = nn.ModuleList(layers)
然后,您的图层将被注册:
>>> model = Mnist_Net(28*28,10, neurons=56)
>>> for t in model.parameters():
... print(t.shape)
torch.Size([56, 784])
torch.Size([56])
torch.Size([56, 56])
torch.Size([56])
torch.Size([10, 56])
torch.Size([10])
torch.Size([56, 56])
torch.Size([56])
我正在创建一个单独的 class 来初始化模型并在列表中添加层,但是这些层没有被添加到参数中,请告诉我如何将它们添加到模型的 parameters() 中。
class Mnist_Net(nn.Module):
def __init__(self,input_dim,output_dim,hidden_layers=2,neurons=128):
super().__init__()
layers = []
for i in range(hidden_layers):
if len(layers) == 0:
layers.append(nn.Linear(input_dim,neurons))
if i == hidden_layers-1:
layers.append(nn.Linear(layers[-2].weight.shape[0],output_dim))
layers.append(nn.Linear(layers[i-1].weight.shape[0],neurons))
self.layers= layers
当我打印 model.parameters()
model = Mnist_Net(28*28,10,neurons=56)
for t in model.parameters():
print(t)
它什么也没显示,但是当我在 class 中添加图层时,比如
self.layer1 = nn.Linear(input_dim,neurons)
它显示 parameters.Plz 中的一层告诉我如何在 model.parameters()
中添加 self.layers 中的所有层要在父模块中注册,您的子模块本身应该是 nn.Module
。在你的情况下,你应该用 nn.ModuleList
:
layers
self.layers = nn.ModuleList(layers)
然后,您的图层将被注册:
>>> model = Mnist_Net(28*28,10, neurons=56)
>>> for t in model.parameters():
... print(t.shape)
torch.Size([56, 784])
torch.Size([56])
torch.Size([56, 56])
torch.Size([56])
torch.Size([10, 56])
torch.Size([10])
torch.Size([56, 56])
torch.Size([56])