用于将上一行的最大值添加到下一行的 PyTorch 张量操作

PyTorch Tensor Operation for adding the maximum of the previous row to the next

的跟进问题。

下面可以写成张量运算而不是循环吗?

a = torch.Tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

print(a.shape) 
# (3, 4)

for i in range(1, a.shape[0]):
    a[i] = a[i-1].max(dim=0)[0] + a[i]

print(a)
# tensor([[ 1,  2,  3,  4],
#         [ 9, 10, 11, 12],
#         [21, 22, 23, 24]])

基本上是将上一行的最大值加到下一行的所有元素上。

有趣的是,您无法预先计算每行的最大值,然后将其添加到相应的行,因为添加第一个最大值会影响下一行的最大值。

不完全确定您为什么要这样做,但是,是的,这是可能的。和你上一个问题基本一样:

max_vals, _ = a.max(axis=1, keepdim=True)
additions = max_vals.cumsum(0)[:-1]
a[1:, :] += additions

这是因为从一行到下一行的边际相加等于最大值,所以可以先取最大值,然后累加加到原来的张量上。