如何使用 PyTorch 快速反转排列?
How to quickly inverse a permutation by using PyTorch?
我对如何快速恢复被排列打乱的数组感到困惑。
示例 #1:
[x, y, z]
被P: [2, 0, 1]
打乱,我们将得到[z, x, y]
- 对应的逆应该是
P^-1: [1, 2, 0]
示例 #2:
[a, b, c, d, e, f]
被 P: [5, 2, 0, 1, 4, 3]
洗牌,然后我们将得到 [f, c, a, b, e, d]
- 对应的逆应该是
P^-1: [2, 3, 1, 5, 4, 0]
我基于矩阵乘法(置换矩阵的转置是它的逆矩阵)写了下面的代码,但是当我用它来训练我的模型时,这种方法太慢了。有没有更快的实现?
import torch
n = 10
x = torch.Tensor(list(range(n)))
print('Original array', x)
random_perm_indices = torch.randperm(n).long()
perm_matrix = torch.eye(n)[random_perm_indices].t()
x = x[random_perm_indices]
print('Shuffled', x)
restore_indices = torch.Tensor(list(range(n))).view(n, 1)
restore_indices = perm_matrix.mm(restore_indices).view(n).long()
x = x[restore_indices]
print('Restored', x)
我在PyTorch Forum中得到了解决方案。
>>> import torch
>>> torch.__version__
'1.7.1'
>>> p1 = torch.tensor ([2, 0, 1])
>>> torch.argsort (p1)
tensor([1, 2, 0])
>>> p2 = torch.tensor ([5, 2, 0, 1, 4, 3])
>>> torch.argsort (p2)
tensor([2, 3, 1, 5, 4, 0])
更新:
由于其线性时间复杂度,以下解决方案更有效:
def inverse_permutation(perm):
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.size(0), device=perm.device)
return inv
我对如何快速恢复被排列打乱的数组感到困惑。
示例 #1:
[x, y, z]
被P: [2, 0, 1]
打乱,我们将得到[z, x, y]
- 对应的逆应该是
P^-1: [1, 2, 0]
示例 #2:
[a, b, c, d, e, f]
被P: [5, 2, 0, 1, 4, 3]
洗牌,然后我们将得到[f, c, a, b, e, d]
- 对应的逆应该是
P^-1: [2, 3, 1, 5, 4, 0]
我基于矩阵乘法(置换矩阵的转置是它的逆矩阵)写了下面的代码,但是当我用它来训练我的模型时,这种方法太慢了。有没有更快的实现?
import torch
n = 10
x = torch.Tensor(list(range(n)))
print('Original array', x)
random_perm_indices = torch.randperm(n).long()
perm_matrix = torch.eye(n)[random_perm_indices].t()
x = x[random_perm_indices]
print('Shuffled', x)
restore_indices = torch.Tensor(list(range(n))).view(n, 1)
restore_indices = perm_matrix.mm(restore_indices).view(n).long()
x = x[restore_indices]
print('Restored', x)
我在PyTorch Forum中得到了解决方案。
>>> import torch
>>> torch.__version__
'1.7.1'
>>> p1 = torch.tensor ([2, 0, 1])
>>> torch.argsort (p1)
tensor([1, 2, 0])
>>> p2 = torch.tensor ([5, 2, 0, 1, 4, 3])
>>> torch.argsort (p2)
tensor([2, 3, 1, 5, 4, 0])
更新: 由于其线性时间复杂度,以下解决方案更有效:
def inverse_permutation(perm):
inv = torch.empty_like(perm)
inv[perm] = torch.arange(perm.size(0), device=perm.device)
return inv