如何沿对角线替换 PyTorch 张量中的特定值?

How to replace specific values in PyTorch tensor along diagonal?

比如有一个PyTorch矩阵A:

A = tensor([[3,2,1],[1,0,2],[2,2,0]])

我需要将对角线上的 0 替换为 1,所以结果应该是:

tensor([[3,2,1],[1,1,2],[2,2,1]])

您可以使用向量索引提取对角线,对其进行处理,然后将其放回原始矩阵中:

N=10
a = torch.randint(0,N,[N,N])

#tensor([[0, 9, 6, 6, 9, 9, 3, 1, 8, 4],
#    [8, 1, 6, 8, 5, 8, 7, 8, 1, 4],
#    [1, 9, 8, 4, 7, 0, 2, 9, 6, 2],
#    [9, 5, 9, 6, 7, 1, 4, 0, 2, 6],
#    [1, 2, 8, 0, 9, 0, 4, 3, 9, 9],
#    [1, 4, 6, 9, 6, 5, 1, 2, 0, 7],
#    [4, 8, 1, 3, 1, 6, 1, 3, 5, 6],
#    [3, 8, 9, 9, 1, 3, 0, 9, 6, 6],
#    [7, 4, 3, 0, 3, 5, 6, 6, 9, 2],
#    [3, 1, 0, 8, 3, 5, 6, 6, 5, 5]])

diag = a[range(N),range(N)] #index (1,1), (2,2), ... etc
diag[diag==0] = 1 # set according to your condition
a[range(N),range(N)] = diag #return the diagonal to its place

#tensor([[1, 9, 6, 6, 9, 9, 3, 1, 8, 4],
#    [8, 1, 6, 8, 5, 8, 7, 8, 1, 4],
#    [1, 9, 8, 4, 7, 0, 2, 9, 6, 2],
#    [9, 5, 9, 6, 7, 1, 4, 0, 2, 6],
#    [1, 2, 8, 0, 9, 0, 4, 3, 9, 9],
#    [1, 4, 6, 9, 6, 5, 1, 2, 0, 7],
#    [4, 8, 1, 3, 1, 6, 1, 3, 5, 6],
#    [3, 8, 9, 9, 1, 3, 0, 9, 6, 6],
#    [7, 4, 3, 0, 3, 5, 6, 6, 9, 2],
#    [3, 1, 0, 8, 3, 5, 6, 6, 5, 5]])

您可以使用 torch 的内置对角线函数来替换对角线元素,如下所示:

mask = A.diagonal() == 0
A += torch.diag(mask)
>>> A
tensor([[3, 2, 1],
        [1, 1, 2],
        [2, 2, 1]])

如果您想用其他值替换 0,请将 mask 更改为 mask * replace_value