Do I need to load the weights of another class I use in my NN class?
I have a model that needs to implement self-attention, and this is how I wrote the code:
class SelfAttention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(args)

    def forward(self, x):
        return self.multihead_attn.forward(x, x, x)

class ActualModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x
After loading a checkpoint of ActualModel, either to continue training or for prediction, should I also load the saved checkpoint of class SelfAttention inside ActualModel.__init__?

If I have created an instance of class SelfAttention and then do torch.load('actual_model.pth'), will the trained weights corresponding to SelfAttention.multihead_attn be loaded, or will they be re-initialized?

In other words, is something like this necessary?
class ActualModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def pred_or_continue_train(self):
        self.self_attention = torch.load('self_attention.pth')

actual_model = torch.load('actual_model.pth')
actual_model.pred_or_continue_train()
actual_model.eval()
In short, no. The SelfAttention class is loaded automatically as long as its contents are registered as nn.Modules, nn.Parameters, or manually registered buffers. Because SelfAttention is assigned as an attribute of ActualModel in __init__, it is registered as a submodule, and its trained weights are restored together with the ActualModel checkpoint.
A simple example:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, fin, n_h):
        super(SelfAttention, self).__init__()
        # Assigning an nn.Module as an attribute registers it (and its
        # parameters) as a submodule of SelfAttention.
        self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)

    def forward(self, x):
        return self.multihead_attn.forward(x, x, x)

class ActualModel(nn.Module):
    def __init__(self):
        super(ActualModel, self).__init__()
        self.inp_layer = nn.Linear(10, 20)
        # SelfAttention is itself an nn.Module, so it is registered as a
        # submodule and its parameters become part of ActualModel's state_dict.
        self.self_attention = SelfAttention(20, 1)
        self.out_layer = nn.Linear(20, 1)

    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

m = ActualModel()
for k, v in m.named_parameters():
    print(k)
You will get something like the following, where self_attention has been registered successfully:
inp_layer.weight
inp_layer.bias
self_attention.multihead_attn.in_proj_weight
self_attention.multihead_attn.in_proj_bias
self_attention.multihead_attn.out_proj.weight
self_attention.multihead_attn.out_proj.bias
out_layer.weight
out_layer.bias
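As a further check, here is a minimal save/load round-trip sketch building on the example above (the file name actual_model.pth is only illustrative): loading the checkpoint into a fresh ActualModel also restores the nested SelfAttention.multihead_attn weights, so no separate SelfAttention checkpoint is needed.

# Save the full model's state_dict; the nested SelfAttention weights are included.
torch.save(m.state_dict(), 'actual_model.pth')

# Build a fresh model (randomly re-initialized) and load the checkpoint into it.
m2 = ActualModel()
m2.load_state_dict(torch.load('actual_model.pth'))

# The nested SelfAttention weights now match the saved ones.
print(torch.equal(
    m.self_attention.multihead_attn.in_proj_weight,
    m2.self_attention.multihead_attn.in_proj_weight,
))  # True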