
Do I need to load the weights of another class I use in my NN class?

I have a model that needs self-attention, and I wrote it like this:

class SelfAttention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(args)

    def forward(self, x):
        return self.multihead_attn(x, x, x)
    
class ActualModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

After loading a checkpoint of ActualModel, should I also load the saved checkpoint of class SelfAttention during ActualModel.__init__, when continuing training, or when predicting?

If I have created an instance of class SelfAttention and then call torch.load('actual_model.pth'), will the trained weights corresponding to SelfAttention.multihead_attn be loaded, or will they be re-initialized?

In other words, is this necessary?

class ActualModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def pred_or_continue_train(self):
        self.self_attention = torch.load('self_attention.pth')

actual_model = torch.load('actual_model.pth')
actual_model.pred_or_continue_train()
actual_model.eval()


In short, no.

The SelfAttention class will be loaded automatically, because anything registered as an nn.Module submodule, an nn.Parameter, or a manually registered buffer becomes part of the parent model's state and is therefore included in its checkpoint.

A quick example:

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, fin, n_h):
        super(SelfAttention, self).__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)
        
    def forward(self, x):
        return self.multihead_attn.forward(x, x, x)
    
class ActualModel(nn.Module):
    def __init__(self):
        super(ActualModel, self).__init__()
        self.inp_layer = nn.Linear(10, 20)
        self.self_attention = SelfAttention(20, 1)
        self.out_layer = nn.Linear(20, 1)
    
    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

m = ActualModel()
for k, v in m.named_parameters():
    print(k)

You will get something like the following, where self_attention has been registered successfully:

inp_layer.weight
inp_layer.bias
self_attention.multihead_attn.in_proj_weight
self_attention.multihead_attn.in_proj_bias
self_attention.multihead_attn.out_proj.weight
self_attention.multihead_attn.out_proj.bias
out_layer.weight
out_layer.bias
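
So a single checkpoint of ActualModel is enough: the nested multihead_attn weights travel with it. Below is a minimal sketch of the round trip (the filename 'actual_model.pth' is just an example, not something from the original post):

m = ActualModel()
torch.save(m.state_dict(), 'actual_model.pth')

# Later, for prediction or to continue training:
restored = ActualModel()
restored.load_state_dict(torch.load('actual_model.pth'))
restored.eval()  # or restored.train() to keep training

# The nested entries, e.g. self_attention.multihead_attn.in_proj_weight,
# are restored from this single checkpoint; no separate
# 'self_attention.pth' is needed.

The same holds if you saved the whole module object with torch.save(m): torch.load('actual_model.pth') brings back the registered self_attention submodule together with its trained weights, so an extra pred_or_continue_train step is unnecessary.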