张量的 Pytorch 成对连接

Pytorch pairwise concatenation of tensors

我想以批处理方式计算特定维度上的成对串联。

例如,

x = torch.tensor([[[0],[1],[2]],[[3],[4],[5]]])
x.shape = torch.Size([2, 3, 1])

我想得到 y 使得 y 是一维所有向量对的串联,即:

y = torch.tensor([[[[0,0],[0,1],[0,2]],[[1,0],[1,1],[1,2]], [[2,0], [2,1], [2,2]]], 
                 [[[3,3],[3,4],[3,5]],[[4,3],[4,4],[4,5]], [[5,3],[5,4],[5,5]]]])

y.shape = torch.Size([2, 3, 3, 2])

所以基本上,对于每个 x[i,:],您生成所有向量对并将它们连接到最后一个维度。 有直接的方法吗?

一种可能的方法是:

    all_ordered_idx_pairs = torch.cartesian_prod(torch.tensor(range(x.shape[1])),torch.tensor(range(x.shape[1])))
    y = torch.stack([x[i][all_ordered_idx_pairs] for i in range(x.shape[0])])

重塑张量后:

y = y.view(x.shape[0], x.shape[1], x.shape[1], -1)

你得到:

y = torch.tensor([[[[0,0],[0,1],[0,2]],[[1,0],[1,1],[1,2]], [[2,0], [2,1], [2,2]]], 
                 [[[3,3],[3,4],[3,5]],[[4,3],[4,4],[4,5]], [[5,3],[5,4],[5,5]]]])

没有循环并使用 torch.arange()。诀窍是广播而不是使用 for 循环。这会将操作应用于具有 : 字符的维度中的所有元素。

x = torch.tensor([
    [[0.0000, 1.0000, 2.0000],
     [3.0000, 4.0000, 5.0000],
     [0.0000, -1.0000, -2.0000],
     [-3.0000, -4.0000, -5.0000]],
    [[0.0000, 10.0000, 20.0000],
     [30.0000, 40.0000, 50.0000],
     [0.0000, -10.0000, -20.0000],
     [-30.0000, -40.0000, -50.0000]
     ]
])
​
idx_pairs = torch.cartesian_prod(torch.arange(x.shape[1]), torch.arange(x.shape[1]))
y = x[:, idx_pairs].view(x.shape[0], x.shape[1], x.shape[1], -1)
tensor([[[[  0.,   1.,   2.,   0.,   1.,   2.],
          [  0.,   1.,   2.,   3.,   4.,   5.],
          [  0.,   1.,   2.,   0.,  -1.,  -2.],
          [  0.,   1.,   2.,  -3.,  -4.,  -5.]],
         [[  3.,   4.,   5.,   0.,   1.,   2.],
          [  3.,   4.,   5.,   3.,   4.,   5.],
          [  3.,   4.,   5.,   0.,  -1.,  -2.],
          [  3.,   4.,   5.,  -3.,  -4.,  -5.]],
         [[  0.,  -1.,  -2.,   0.,   1.,   2.],
          [  0.,  -1.,  -2.,   3.,   4.,   5.],
          [  0.,  -1.,  -2.,   0.,  -1.,  -2.],
          [  0.,  -1.,  -2.,  -3.,  -4.,  -5.]],
         [[ -3.,  -4.,  -5.,   0.,   1.,   2.],
          [ -3.,  -4.,  -5.,   3.,   4.,   5.],
          [ -3.,  -4.,  -5.,   0.,  -1.,  -2.],
          [ -3.,  -4.,  -5.,  -3.,  -4.,  -5.]]],
        [[[  0.,  10.,  20.,   0.,  10.,  20.],
          [  0.,  10.,  20.,  30.,  40.,  50.],
          [  0.,  10.,  20.,   0., -10., -20.],
          [  0.,  10.,  20., -30., -40., -50.]],
         [[ 30.,  40.,  50.,   0.,  10.,  20.],
          [ 30.,  40.,  50.,  30.,  40.,  50.],
          [ 30.,  40.,  50.,   0., -10., -20.],
          [ 30.,  40.,  50., -30., -40., -50.]],
         [[  0., -10., -20.,   0.,  10.,  20.],
          [  0., -10., -20.,  30.,  40.,  50.],
          [  0., -10., -20.,   0., -10., -20.],
          [  0., -10., -20., -30., -40., -50.]],
         [[-30., -40., -50.,   0.,  10.,  20.],
          [-30., -40., -50.,  30.,  40.,  50.],
          [-30., -40., -50.,   0., -10., -20.],
          [-30., -40., -50., -30., -40., -50.]]]])