PyTorch 中的 Concat 张量

Concat tensors in PyTorch

我有一个名为 data 的张量,形状为 [128, 4, 150, 150],其中 128 是批量大小,4 是通道数,最后 2 个维度是高度和宽度。我有另一个名为 fake 的张量,形状为 [128, 1, 150, 150].

我想从 data 的第二个维度中删除最后一个 list/array;数据的形状现在是 [128, 3, 150, 150];并将其与 fake 连接起来,给出连接的输出维度 [128, 4, 150, 150].

基本上,换句话说,我想将 data 的前 3 个维度与 fake 连接起来,得到一个 4 维张量。

我正在使用 PyTorch 并遇到了函数 torch.cat()torch.stack()

这是我编写的示例代码:

fake_combined = []
        for j in range(batch_size):
            fake_combined.append(torch.stack((data[j][0].to(device), data[j][1].to(device), data[j][2].to(device), fake[j][0].to(device))))
fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
fake_combined = fake_combined.to(device)

但是我在行中收到错误:

fake_combined = torch.tensor(fake_combined, dtype=torch.float32)

错误是:

ValueError: only one element tensors can be converted to Python scalars

此外,如果我打印 fake_combined 的形状,我得到的输出是 [128,] 而不是 [128, 4, 150, 150]

当我打印 fake_combined[0] 的形状时,我得到的输出是 [4, 150, 150],这是预期的。

所以我的问题是,为什么我无法使用 torch.tensor() 将列表转换为张量。我错过了什么吗?有没有更好的方法来做我打算做的事情?

任何帮助将不胜感激!谢谢!

您也可以只分配给那个特定的维度。

orig = torch.randint(low=0, high=10, size=(2,3,2,2))
fake = torch.randint(low=111, high=119, size=(2,1,2,2))
orig[:,[2],:,:] = fake

原始之前

tensor([[[[0, 1],
      [8, 0]],

     [[4, 9],
      [6, 1]],

     [[8, 2],
      [7, 6]]],


    [[[1, 1],
      [8, 5]],

     [[5, 0],
      [8, 6]],

     [[5, 5],
      [2, 8]]]])

假的

tensor([[[[117, 115],
      [114, 111]]],


    [[[115, 115],
      [118, 115]]]])

原版之后

tensor([[[[  0,   1],
      [  8,   0]],

     [[  4,   9],
      [  6,   1]],

     [[117, 115],
      [114, 111]]],


    [[[  1,   1],
      [  8,   5]],

     [[  5,   0],
      [  8,   6]],

     [[115, 115],
      [118, 115]]]])

希望对您有所帮助! :)

@rollthedice32 的回答非常好。出于教育目的,这里使用 torch.cat

a = torch.rand(128, 4, 150, 150)
b = torch.rand(128, 1, 150, 150)

# Cut out last dimension
a = a[:, :3, :, :]
# Concatenate in 2nd dimension
result = torch.cat([a, b], dim=1)
print(result.shape)
# => torch.Size([128, 4, 150, 150])