How to add an additional output node during training in PyTorch?

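I am building a class-incremental multi-label classifier. The model is first trained with 7 labels. After training, another dataset arrives that contains the same labels plus one more. I want to automatically add an extra node to the trained network and continue training on this new dataset. How can I do that?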

import torch
import torch.nn as nn

class FeedForewardNN(nn.Module):
    def __init__(self, input_size, h1_size = 264, h2_size = 128, num_services=8):
        super().__init__()
        self.input_size = input_size
        self.lin1 = nn.Linear(input_size, h1_size)
        self.lin2 = nn.Linear(h1_size, h2_size)
        self.lin3 = nn.Linear(h2_size, num_services)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)
        x = self.sigmoid(x)
        return x

This is the architecture of the feed-forward neural network. I first train it on a dataset with only 7 classes.

import torch.optim as optim

#Create NN
input_size = len(x_columns)
net1 = FeedForewardNN(input_size, num_services=7)
alpha= 0.001

#Define optimizer (it must receive the parameters of net1, the model being trained)
optimizer = optim.Adam(net1.parameters(), lr=alpha)
criterion = nn.BCELoss()
running_loss = 0

#Training Loop
loss_list = []
auc_list = []

for i in range(len(train_data_x)):
    optimizer.zero_grad()

    outputs = net1(train_data_x[i])
    loss = criterion(outputs, train_data_y[i])
    loss.backward()
    optimizer.step()

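However, I now want to add one extra output node, defining new weights for it while keeping the old trained weights, and then train on this new dataset.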

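I suggest replacing the layer with a new layer of the desired shape, then filling part of its parameters with the old trained values, like this: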

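import torch
import torch.nn as nn

def increaseClassifier( m: torch.nn.Linear ):
    old_shape = m.weight.shape  # (out_features, in_features)

    # New layer with one more output; its parameters start from the default random init
    m2 = nn.Linear( old_shape[1], old_shape[0] +1 )
    # Keep the trained rows and append one freshly initialized row for the new class
    m2.weight = nn.parameter.Parameter( torch.cat( (m.weight, m2.weight[0:1]) ) )
    m2.bias = nn.parameter.Parameter( torch.cat( (m.bias, m2.bias[0:1]) ) )
    return m2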

class FeedForewardNN(nn.Module):
    ...
    def incrHere(self):
        self.lin3 = increaseClassifier( self.lin3 )
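
One detail worth spelling out: after the layer is swapped, an optimizer created earlier still holds references to the old `lin3` parameters, so it should be re-created before training continues. The following is only a minimal sketch of that step. It uses a standalone `nn.Linear(128, 7)` as a stand-in for the trained `self.lin3`, and the batch `x` is made-up data for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def increaseClassifier(m: nn.Linear) -> nn.Linear:
    # Same idea as above: keep the trained rows, append one freshly initialized row.
    m2 = nn.Linear(m.in_features, m.out_features + 1)
    m2.weight = nn.Parameter(torch.cat((m.weight, m2.weight[0:1])))
    m2.bias = nn.Parameter(torch.cat((m.bias, m2.bias[0:1])))
    return m2


torch.manual_seed(0)
head = nn.Linear(128, 7)      # stand-in for the trained self.lin3
x = torch.randn(4, 128)       # hypothetical batch of hidden features
old_out = head(x).detach()

new_head = increaseClassifier(head)
new_out = new_head(x).detach()

# The first 7 outputs are unchanged; only the 8th depends on the new weights.
outputs_preserved = torch.allclose(old_out, new_out[:, :7], atol=1e-6)

# The old optimizer still points at the replaced parameters, so build a new one.
optimizer = optim.Adam(new_head.parameters(), lr=0.001)
```

With the full model, the equivalent would be calling `net1.incrHere()` and then creating a fresh `optim.Adam(net1.parameters(), lr=alpha)`. The targets in the new dataset also need 8 columns to match the enlarged output.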

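Update:

Can you explain how these additional weights that come with this new output node are initialized?

The initial weights of the new channel come from creating the new layer: the layer constructor generates new parameters with some random initialization, then we replace part of them with the trained weights, and the remaining part is ready to be trained.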

m2.weight = nn.parameter.Parameter( torch.cat( (m.weight, m2.weight[0:1] ) ) )
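
To make this concrete, here is a small check one could run. It is a sketch with stand-in layers rather than the trained model. It confirms that the first 7 rows are the old weights and that the appended row is the one produced by the default `nn.Linear` initialization, which the PyTorch documentation describes as uniform in `(-1/sqrt(in_features), 1/sqrt(in_features))`.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
m = nn.Linear(128, 7)     # stand-in for the trained layer
m2 = nn.Linear(128, 8)    # fresh layer with default random initialization
fresh_row = m2.weight[0:1].detach().clone()  # the row that gets appended

m2.weight = nn.Parameter(torch.cat((m.weight, m2.weight[0:1])))

# Trained rows are copied unchanged.
rows_kept = torch.equal(m2.weight[:7], m.weight)
# The last row is exactly the randomly initialized row taken from the new layer.
row_is_fresh = torch.equal(m2.weight[7:8], fresh_row)
# Its values lie within the default initialization bound (small float tolerance).
bound = 1.0 / math.sqrt(m.in_features)
within_bound = bool(m2.weight[7].abs().max() <= bound + 1e-6)
```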