张量操作——从给定的张量创建位置张量

Tensor manipulation - creating a positional tensor from a given tensor

我有一个输入张量,它在开始时有零填充,然后是一系列值。所以像:

x = torch.tensor([[0, 2, 8, 12],
                  [0, 0, 6, 3]])

我需要的是另一个具有相同形状并为填充保留 0 并为其余数字保留递增序列的张量。所以我的输出张量应该是:

y = ([[0, 1, 2, 3],
      [0, 0, 1, 2]])

我试过类似的东西:

MAX_SEQ=4
seq_start = np.nonzero(x)
start = seq_start[0][0]
pos_id = torch.cat((torch.from_numpy(np.zeros(start, dtype=int)).to(device), torch.arange(1, MAX_SEQ-start+1).to(device)), 0)
print(pos_id)

如果张量是一维的但需要额外的逻辑来处理二维形状,这会起作用。这可以作为 np.nonzeros returns 一个元组来完成,我们可能会循环遍历那些更新计数器或其他东西的元组。但是我确信必须有一个简单的张量操作,它应该在 1-2 行代码中完成,而且可能更有效。

感谢帮助

三个小步骤的可能解决方案:

  1. 找到每一行的第一个非零元素的索引。这可以通过 解释的技巧来完成(此处适用于非二进制张量 )。

    > idx = torch.arange(x.shape[1], 0, -1)
    tensor([4, 3, 2, 1])
    
    > xbin = torch.where(x == 0, 0, 1)
    tensor([[0, 1, 1, 1],
            [0, 0, 1, 1]])
    
    > xbin*idx
    tensor([[0, 3, 2, 1],
            [0, 0, 2, 1]])
    
    > indices = torch.argmax(xbin*idx, dim=1, keepdim=True)
    tensor([[1],
            [2]])
    
  2. 为生成的张量创建一个排列(没有填充)。这可以通过在 torch.arange call:

    上应用 torch.repeattorch.view 来完成
    > rows, cols = x.shape
    > seq = torch.arange(1, cols+1).repeat(1, rows).view(-1, cols)
    tensor([[1, 2, 3, 4],
            [1, 2, 3, 4]])
    
  3. 最后 - 诀窍来了! - 对于每一行,我们用排列减去第一个非零元素的索引。然后我们屏蔽填充值并将它们替换为零:

    > pos_id = seq - indices
    tensor([[ 0,  1,  2,  3],
            [-1,  0,  1,  2]])
    
    > mask = indices > seq - 1
    tensor([[ True, False, False, False],
            [ True,  True, False, False]])
    
    > pos_id[mask] = 0
    tensor([[0, 1, 2, 3],
            [0, 0, 1, 2]])
    

扩展 Ivan 的好答案以包括批量大小,因为我的模型有那个。这 'seems' 起作用了。仅供参考,如果超过2D要考虑

x = torch.tensor([[[ 0,  0,  2,  3,  4,  5,  6,  7,  8,  9],
                [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],

               [[0, 0, 0, 0, 0, 0, 26, 27, 28, 29],
                [0, 31, 32, 33, 34, 35, 36, 37, 38, 39]],

               [[0, 0, 42, 43, 44, 45, 46, 47, 48, 49],
                [0, 0, 0, 53, 0, 55, 56, 57, 58, 59]]])

bs, rows, cols = x.shape
seq = torch.arange(1, cols+1).repeat(1, rows).repeat(1, bs).view(bs, rows, cols)

idx = torch.arange(x.shape[-1], 0, -1)
xbin = torch.where(x == 0, 0, 1)
indices = torch.argmax(xbin*idx, dim=2, keepdim=True)

pos_id = seq - indices
mask = indices > seq - 1
pos_id[mask] = 0
print(pos_id)