如何右移矩阵的每一行?

How to right shift each row of a matrix?

我有一个矩阵,其形状为 (TxK, and K << T)。我想将其扩展为 TxT 形状,然后将第 i 行右移 i 步。

举个例子:

inputs: T= 5, and K = 3
1 2 3
1 2 3
1 2 3
1 2 3
1 2 3

expected outputs:
1 2 3 0 0
0 1 2 3 0
0 0 1 2 3
0 0 0 1 2
0 0 0 0 1

我的解决方案:

right_pad = T - K + 1
output = F.pad(input, (0, right_pad), 'constant', value=0)
output = output.view(-1)[:-T].view(T, T)

我的解决方案会导致错误 -- gradient computation has been modified by an in-place operation。有没有一种高效可行的方法可以达到我的目的?

您可以使用 PyTorch 逐列执行此操作。

# input is a T * K tensor
input = torch.ones((T, K))

index = torch.tensor(np.linspace(0, T - 1, num=T, dtype=np.int64))
output = torch.zeros((T, T))
output[index, index] = input[:, 0]
for k in range(1, K):
    output[index[:-k], index[:-k] + k] = input[:-k, k]
print(output)

您的功能很好,不是导致错误的原因(使用PyTorch 1.6.0,如果您使用的是其他版本,请更新您的依赖项)。

下面的代码工作正常:

import torch
import torch.nn as nn
import torch.nn.functional as F

T = 5
K = 3

inputs = torch.tensor(
    [[1, 2, 3,], [1, 2, 3,], [1, 2, 3,], [1, 2, 3,], [1, 2, 3,],],
    requires_grad=True,
    dtype=torch.float,
)

right_pad = T - K + 1
output = F.pad(inputs, (0, right_pad), "constant", value=0)
output = output.flatten()[:-T].reshape(T, T)

output.sum().backward()

print(inputs.grad)

请注意,我已将 dtype 明确指定为 torch.float,因为您不能 backprop 整数。

viewslice 永远不会 中断反向传播,因为 gradient 连接到单个值,无论它是否被查看作为 1D 或未压缩的 2D 或其他。那些没有修改 in-place。 In-place修改打破梯度可以是:

output[0, 3] = 15.

另外,你的解决方案returns这个:

tensor([[1., 2., 3., 0., 0.],
        [0., 1., 2., 3., 0.],
        [0., 0., 1., 2., 3.],
        [0., 0., 0., 1., 2.],
        [3., 0., 0., 0., 1.]], grad_fn=<ViewBackward>)

所以你在左下角有一个 3。如果这不是您所期望的,您应该在 output = output.flatten()[:-T].reshape(T,T):

之后添加这一行(乘以 1 的上三角矩阵)
output *= torch.triu(torch.ones_like(output))

给出:

tensor([[1., 2., 3., 0., 0.],
        [0., 1., 2., 3., 0.],
        [0., 0., 1., 2., 3.],
        [0., 0., 0., 1., 2.],
        [0., 0., 0., 0., 1.]], grad_fn=<AsStridedBackward>)

inputs.grad

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 0.],
        [1., 0., 0.]])