我如何 trim / 删除张量的一部分以匹配另一个张量的形状与 PyTorch?

How can I trim / remove part of a Tensor to match the shape of another Tensor with PyTorch?

我有 2 个张量:

outputs: torch.Size([4, 27, 161])       pred: torch.Size([4, 30, 161])

我想剪切 pred(从末尾开始),使其与 outputs 具有相同的尺寸。

使用 PyTorch 的最佳方法是什么?

您可以使用Narrow

例如:

a = torch.randn(4,30,161)
a.size() # torch.Size([4, 30, 161])
a.narrow(1,0,27).size() # torch.Size([4, 27, 161])

如果你有固定的两个张量的维数,你可以试试这个:

a = torch.randn(3, 5)
b = torch.zeros(3, 2)
b_h, b_w = b.shape
c = a[:b_h, :b_w]  # torch.Size([3, 2])

c 的形状与 b 相同,但与 a 的值相同。