PyTorch Batch 屏蔽 select 实现

PyTorch Batch masked select implementation

如何执行 批处理 masked_select

鉴于:

x = torch.tensor([[1., 2., 2., 2., 3.],
                  [1., 2., 4., 3., 2.]])

所需的输出将是:

tensor([[1., 3., 1., 1., 1.],
        [1., 4., 3., 1., 1.]])

这是一个可能的方法:

x = torch.tensor([[1., 2., 2., 2., 3.],
                  [1., 2., 4., 3., 2.]])

ones = torch.tensor([[1., 1., 1., 1., 1.],
                     [1., 1., 1., 1., 1.]])

masks = torch.tensor([[ True, False, False, False,  True],
                      [ True, False,  True,  True, False]])

for i in range(x.size(0)):
    mask = masks[i]
    s = torch.masked_select(x[i], mask)
    ones[i][:s.size(0)] = s

是否有其他解决方案?

这类问题的主要问题是中间结果是非同质的:在您的批次中,元素将具有不同数量的掩码值。如果我们想应用 PyTorch 内置函数,这是一个问题。这里我提供两种解决方案来执行此操作。


1- 使用 list comprehension

通过适当数量的批处理元素、掩码和填充:

>>> pad = lambda v: F.pad(v, [0, len(m)-len(v)], value=1)
>>> torch.stack([pad(r[m]) for r, m in zip(x, masks)])
tensor([[1., 3., 1., 1., 1.],
        [1., 4., 3., 1., 1.]])

这很简单,与您的方法相似。


2- 使用 torch.scatter

向量化替代方案是构造正确的值和索引张量,以便我们可以应用 torch.scatter 并获得所需的结果。这里的技巧是使用扁平张量。从 xmasks 我们首先要访问 nzidx 定义为:

  • nz:来自 x 的非屏蔽值(当然使用 masks), 我们需要找到:

    tensor([1., 3., 1., 4., 3.]) 
    
  • idx:它们在输出张量中的对应索引展平

    tensor([ 0,  1,  5,  6,  7])
    

然后我们可以应用类似 out = ones.scatter(dim=0, idx, nz) 的散点图,这将有效地执行:out[idx[i]] = nz[i].

要构造nz,我们可以直接用masks索引masks非零值索引masks:

>>> nz = x[masks]
tensor([1., 3., 1., 4., 3.])

对于idx,这会有点棘手。我们可以对掩码本身进行排序,将其展平并使用 torch.Tensor.nonzero 获得非零值。排序后,True 值在每行的开头结束:

>>> idx = masks.sort(1, True).values.view(-1).nonzero()[:,0]
tensor([ 0,  1,  5,  6,  7])

最后我们可以应用 torch.scatter 并重塑以获得所需的结果:

>>> torch.ones(x.numel()).scatter(0, idx, nz).view_as(x)
tensor([[1., 3., 1., 1., 1.],
        [1., 4., 3., 1., 1.]])

此处torch.scatter的使用受到限制,因为输入是一维的。一个等效的方法是简单地:

>>> o = torch.ones(x.numel())
>>> o[idx] = nz
>>> o.view_as(x)

完整方法:

>>> idx = masks.sort(1, True)[0].view(-1).nonzero()[:,0]
>>> torch.ones(x.numel()).scatter(0, idx, x[masks]).view_as(x)
tensor([[1., 3., 1., 1., 1.],
        [1., 4., 3., 1., 1.]])