屏蔽:屏蔽指定令牌后的所有内容(eos)

Masking: Mask everything after a specified token (eos)

我的 tgt 张量的形状是 [12, 32, 1],即 sequence_length, batch_size, token_idx

创建掩码的最佳方法是什么?<eos> 和之前的条目顺序为 1,之后为 0?

目前我正在这样计算我的掩码,它只是将零放在 <blank> 所在的位置,否则为零。

mask = torch.zeros_like(tgt).masked_scatter_((tgt != tgt_padding), torch.ones_like(tgt))

但问题是,我的 tgt 也可以包含 <blank>(在 <eos> 之前),在这种情况下我不想将其屏蔽掉。

我的临时解决方案:

mask = torch.ones_like(tgt)
for eos_token in (tgt == tgt_eos).nonzero():
    mask[eos_token[0]+1:,eos_token[1]] = 0

我猜您正在尝试为 PAD 令牌创建掩码。有几种方法。其中之一如下

# tensor is of shape [seq_len, batch_size, 1]
tensor = tensor.mul(tensor.ne(PAD).float())

这里,PAD代表PAD_TOKEN的索引。 tensor.ne(PAD) 将创建一个字节张量,其中在 PAD_TOKEN 位置,将分配 0,在其他位置分配 1。


如果你有这样的例子,"<s> I think <pad> so </s> <pad> <pad>"。然后,我建议在 </s> 之前和之后使用不同的 PAD 令牌。

或者,如果你有每个句子的长度信息(在上面的例子中,句子长度是6),那么你可以使用下面的函数创建掩码。

def sequence_mask(lengths, max_len=None):
    """
    Creates a boolean mask from sequence lengths.
    :param lengths: 1d tensor [batch_size]
    :param max_len: int
    """
    batch_size = lengths.numel()
    max_len = max_len or lengths.max()
    return (torch.arange(0, max_len, device=lengths.device)  # (0 for pad positions)
            .type_as(lengths)
            .repeat(batch_size, 1)
            .lt(lengths.unsqueeze(1)))