在 PyTorch 中用矢量替换对角线元素

Replace diagonal elements with vector in PyTorch

我一直在到处寻找与 PyTorch 等价的东西,但找不到任何东西。

L_1 = np.tril(np.random.normal(scale=1., size=(D, D)), k=0)
L_1[np.diag_indices_from(L_1)] = np.exp(np.diagonal(L_1))

我想用 Pytorch 没有办法以如此优雅的方式替换对角线元素。

我认为目前还没有实现这样的功能。但是,您可以使用 mask 实现相同的功能,如下所示。

# Assuming v to be the vector and a be the tensor whose diagonal is to be replaced
mask = torch.diag(torch.ones_like(v))
out = mask*torch.diag(v) + (1. - mask)*a

因此,您的实施将类似于

L_1 = torch.tril(torch.randn((D, D)))
v = torch.exp(torch.diag(L_1))
mask = torch.diag(torch.ones_like(v))
L_1 = mask*torch.diag(v) + (1. - mask)*L_1

不如 numpy 优雅,但也不错。

有更简单的方法

dest_matrix[range(len(dest_matrix)), range(len(dest_matrix))] = source_vector

实际上我们必须自己生成对角线索引。

用法示例:

dest_matrix = torch.randint(10, (3, 3))
source_vector = torch.randint(100, 200, (len(dest_matrix), ))
print('dest_matrix:\n', dest_matrix)
print('source_vector:\n', source_vector)

dest_matrix[range(len(dest_matrix)), range(len(dest_matrix))] = source_vector

print('result:\n', dest_matrix)

# dest_matrix:
#  tensor([[3, 2, 5],
#         [0, 3, 5],
#         [3, 1, 1]])
# source_vector:
#  tensor([182, 169, 147])
# result:
#  tensor([[182,   2,   5],
#         [  0, 169,   5],
#         [  3,   1, 147]])

如果 dest_matrix 不是正方形,您必须在 range()

中使用 min(dest_matrix.size()) 而不是 len(dest_matrix)

不如numpy优雅,但这不需要存储新的索引矩阵。

是的,这保留了梯度

您可以使用 diagonal() 提取对角线元素,然后使用 copy_():

就地分配转换后的值
new_diags = L_1.diagonal().exp()
L_1.diagonal().copy_(new_diags)

为简单起见,假设您有一个矩阵 L_1 并且想用零替换它的对角线。您可以通过多种方式执行此操作。

使用fill_diagonal_():

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
L_1 = L_1.fill_diagonal_(0.)

使用高级索引:

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
length = len(L_1)
zero_vector = torch.zeros(length, dtype=torch.float32)
L_1[range(length), range(length)] = zero_vector

使用scatter_():

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)  
diag_idx = torch.arange(len(L_1)).unsqueeze(1)
zero_matrix = torch.zeros(L_1.shape, dtype=torch.float32)
L_1 = L_1.scatter_(1, diag_idx, zero_matrix) 

请注意,上述所有解决方案都是 in-place 操作,并且会影响向后传递,因为可能需要原始值来计算它。因此,如果你想保持向后传递不受影响,意思是通过不记录变化(操作)来“打破图表”,这意味着不计算与你在前向传递中计算的内容相对应的向后传递中的梯度,那么你可以使用高级索引时添加 .datascatter_().

使用高级索引 .data:

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)
length = len(L_1)
zero_vector = torch.zeros(length, dtype=torch.float32)
L_1[range(length), range(length)] = zero_vector.data

scatter_().data 一起使用:

L_1 = torch.tensor([[1,2,3],[4,5,6],[7,8,9]], dtype=torch.float32)  
diag_idx = torch.arange(len(L_1)).unsqueeze(1)
zero_matrix = torch.zeros(L_1.shape, dtype=torch.float32)
L_1 = L_1.scatter_(1, diag_idx, zero_matrix.data)

参考 this 讨论。