Transformer中target的Pytorch NLP序列长度

Pytorch NLP sequence length of target in Transformer

我正在尝试理解 Transformer (https://github.com/SamLynnEvans/Transformer) 的代码。

如果在“train”脚本中看到 train_model 函数,我想知道为什么需要使用与 trg:

不同的 trg_input 序列长度
trg_input = trg[:, :-1]

在这种情况下,trg_input的序列长度是“seq_len(trg) - 1”。 这意味着 trg 就像:

<sos> tok1 tok2 tokn <eos>

和 trg_input 就像:

<sos> tok1 tok2 tokn    (no eos token)

请告诉我原因。

谢谢。

相关代码如下:

    for i, batch in enumerate(opt.train):
        src = batch.src.transpose(0, 1).to('cuda')
        trg = batch.trg.transpose(0, 1).to('cuda')

        trg_input = trg[:, :-1]
        src_mask, trg_mask = create_masks(src, trg_input, opt)
        preds = model(src, trg_input, src_mask, trg_mask)
        ys = trg[:, 1:].contiguous().view(-1)
        opt.optimizer.zero_grad()
        loss = F.cross_entropy(preds.view(-1, preds.size(-1)), ys, ignore_index=opt.trg_pad)
        loss.backward()
        opt.optimizer.step()


def create_masks(src, trg, opt):
    
    src_mask = (src != opt.src_pad).unsqueeze(-2)

    if trg is not None:
        trg_mask = (trg != opt.trg_pad).unsqueeze(-2)
        size = trg.size(1) # get seq_len for matrix
        np_mask = nopeak_mask(size, opt)
        if trg.is_cuda:
            np_mask.cuda()
        trg_mask = trg_mask & np_mask
        
    else:
        trg_mask = None
    return src_mask, trg_mask

那是因为整个目标是根据我们目前看到的令牌生成下一个令牌。当我们得到预测时,看看模型的输入。我们不仅提供源序列,还提供目标序列 ,直到我们当前的步骤 Models.py 里面的模型看起来像:

class Transformer(nn.Module):
    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        self.decoder = Decoder(trg_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, trg_vocab)
    def forward(self, src, trg, src_mask, trg_mask):
        e_outputs = self.encoder(src, src_mask)
        #print("DECODER")
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = self.out(d_output)
        return output

所以你可以看到 forward 方法接收 srctrg,它们分别被送入编码器和解码器。如果您从 the original paper:

查看模型架构,这会更容易理解

“输出(右移)”对应代码中的trg[:, :-1]