有效地将数字添加到张量的每个维度的末尾
Efficiently add numbers to the end of each dimension of a tensor
我有一个形状为 (n, 200) 的张量 x。我想通过将 18 个数字的张量附加到当前张量的每个“行”的末尾来使其形状为 (n, 218)。 n 根据批量大小而变化,所以我想要一种方法来对任何 n 执行此操作。
截至目前,我有一个可行的解决方案,但我想知道是否有内置的方法可以做到这一点,我在文档中没有看到任何特别的东西。
我的方法是:
import torch.nn.functional as F
x = F.pad(input = x, (0, 18, 0, 0)) # pad each tensor in dim 2 with 18 zeroes
for index in range(x.shape[0]):
x[index][-18] = nums_to_add # nums_to_add is a tensor with size (1,18)
这很好用,但我想知道是否有更简单的方法来做到这一点,而无需先填充零。
torch.cat()
就是您要找的。这是一个片段:
import torch
a = torch.randint(1,10,(3,4))
b = torch.randint(1,10,(3,2))
print(a)
print(b)
a = torch.cat((a,b),axis=1) # axis should be one here
print(a)
输出
tensor([[2, 5, 3, 8],
[3, 9, 5, 3],
[9, 4, 9, 9]])
tensor([[6, 4],
[1, 1],
[8, 3]])
tensor([[2, 5, 3, 8, 6, 4],
[3, 9, 5, 3, 1, 1],
[9, 4, 9, 9, 8, 3]])
现在这里是一个类似的例子,只是使用 repeat
使它在 dim=0
中具有相同的形状,这样我们就可以轻松地将它连接起来。 (尝试完全遵循 OP 的建议)
import torch
a = torch.randint(1,10,(5,200)) # shape(5,200)
b = torch.randint(1,10,(1,18)).repeat((5,1)) # shape(5,18)
a = torch.cat((a,b),axis=1) # axis should be one here
print(a.shape) # (5,218)
上述解决方案中唯一棘手的部分是 repeat()
部分(如果你可以说它复杂的话......),它基本上沿着指定的维度重复这个张量。检查 here.
我有一个形状为 (n, 200) 的张量 x。我想通过将 18 个数字的张量附加到当前张量的每个“行”的末尾来使其形状为 (n, 218)。 n 根据批量大小而变化,所以我想要一种方法来对任何 n 执行此操作。
截至目前,我有一个可行的解决方案,但我想知道是否有内置的方法可以做到这一点,我在文档中没有看到任何特别的东西。
我的方法是:
import torch.nn.functional as F
x = F.pad(input = x, (0, 18, 0, 0)) # pad each tensor in dim 2 with 18 zeroes
for index in range(x.shape[0]):
x[index][-18] = nums_to_add # nums_to_add is a tensor with size (1,18)
这很好用,但我想知道是否有更简单的方法来做到这一点,而无需先填充零。
torch.cat()
就是您要找的。这是一个片段:
import torch
a = torch.randint(1,10,(3,4))
b = torch.randint(1,10,(3,2))
print(a)
print(b)
a = torch.cat((a,b),axis=1) # axis should be one here
print(a)
输出
tensor([[2, 5, 3, 8],
[3, 9, 5, 3],
[9, 4, 9, 9]])
tensor([[6, 4],
[1, 1],
[8, 3]])
tensor([[2, 5, 3, 8, 6, 4],
[3, 9, 5, 3, 1, 1],
[9, 4, 9, 9, 8, 3]])
现在这里是一个类似的例子,只是使用 repeat
使它在 dim=0
中具有相同的形状,这样我们就可以轻松地将它连接起来。 (尝试完全遵循 OP 的建议)
import torch
a = torch.randint(1,10,(5,200)) # shape(5,200)
b = torch.randint(1,10,(1,18)).repeat((5,1)) # shape(5,18)
a = torch.cat((a,b),axis=1) # axis should be one here
print(a.shape) # (5,218)
上述解决方案中唯一棘手的部分是 repeat()
部分(如果你可以说它复杂的话......),它基本上沿着指定的维度重复这个张量。检查 here.