多维张量的 Top K 指数
Top K indices of a multi-dimensional tensor
我有一个二维张量,我想获取前 k 个值的索引。我知道 pytorch's topk 函数。 pytorch 的 topk 函数的问题在于,它计算某个维度上的 topk 值。我想获得两个维度的 topk 值。
例如下面的张量
a = torch.tensor([[4, 9, 7, 4, 0],
[8, 1, 3, 1, 0],
[9, 8, 4, 4, 8],
[0, 9, 4, 7, 8],
[8, 8, 0, 1, 4]])
pytorch 的 topk 函数会给我以下信息。
values, indices = torch.topk(a, 3)
print(indices)
# tensor([[1, 2, 0],
# [0, 2, 1],
# [0, 1, 4],
# [1, 4, 3],
# [1, 0, 4]])
但我想得到以下
tensor([[0, 1],
[2, 0],
[3, 1]])
这是二维张量中9的索引。
是否有任何方法可以使用 pytorch 实现此目的?
您可以 flatten
原始张量,应用 topk
然后将结果标量索引转换回多维索引,如下所示:
def descalarization(idx, shape):
res = []
N = np.prod(shape)
for n in shape:
N //= n
res.append(idx // N)
idx %= N
return tuple(res)
示例:
torch.tensor([descalarization(k, a.size()) for k in torch.topk(a.flatten(), 5).indices])
# Returns
# tensor([[3, 1],
# [2, 0],
# [0, 1],
# [3, 4],
# [2, 4]])
v, i = torch.topk(a.flatten(), 3)
print (np.array(np.unravel_index(i.numpy(), a.shape)).T)
输出:
[[3 1]
[2 0]
[0 1]]
- 展平并找到顶部 k
- 使用
unravel_index
将 1D 索引转换为 2D
你可以根据自己的需要做一些向量运算来过滤。在这种情况下不使用 topk。
print(a)
tensor([[4, 9, 7, 4, 0],
[8, 1, 3, 1, 0],
[9, 8, 4, 4, 8],
[0, 9, 4, 7, 8],
[8, 8, 0, 1, 4]])
values, indices = torch.max(a,1) # get max values, indices
temp= torch.zeros_like(values) # temporary
temp[values==9]=1 # fill temp where values are 9 (wished value)
seq=torch.arange(values.shape[0]) # create a helper sequence
new_seq=seq[temp>0] # filter sequence where values are 9
new_temp=indices[new_seq] # filter indices with sequence where values are 9
final = torch.stack([new_seq, new_temp], dim=1) # stack both to get result
print(final)
tensor([[0, 1],
[2, 0],
[3, 1]])
我有一个二维张量,我想获取前 k 个值的索引。我知道 pytorch's topk 函数。 pytorch 的 topk 函数的问题在于,它计算某个维度上的 topk 值。我想获得两个维度的 topk 值。
例如下面的张量
a = torch.tensor([[4, 9, 7, 4, 0],
[8, 1, 3, 1, 0],
[9, 8, 4, 4, 8],
[0, 9, 4, 7, 8],
[8, 8, 0, 1, 4]])
pytorch 的 topk 函数会给我以下信息。
values, indices = torch.topk(a, 3)
print(indices)
# tensor([[1, 2, 0],
# [0, 2, 1],
# [0, 1, 4],
# [1, 4, 3],
# [1, 0, 4]])
但我想得到以下
tensor([[0, 1],
[2, 0],
[3, 1]])
这是二维张量中9的索引。
是否有任何方法可以使用 pytorch 实现此目的?
您可以 flatten
原始张量,应用 topk
然后将结果标量索引转换回多维索引,如下所示:
def descalarization(idx, shape):
res = []
N = np.prod(shape)
for n in shape:
N //= n
res.append(idx // N)
idx %= N
return tuple(res)
示例:
torch.tensor([descalarization(k, a.size()) for k in torch.topk(a.flatten(), 5).indices])
# Returns
# tensor([[3, 1],
# [2, 0],
# [0, 1],
# [3, 4],
# [2, 4]])
v, i = torch.topk(a.flatten(), 3)
print (np.array(np.unravel_index(i.numpy(), a.shape)).T)
输出:
[[3 1]
[2 0]
[0 1]]
- 展平并找到顶部 k
- 使用
unravel_index
将 1D 索引转换为 2D
你可以根据自己的需要做一些向量运算来过滤。在这种情况下不使用 topk。
print(a)
tensor([[4, 9, 7, 4, 0],
[8, 1, 3, 1, 0],
[9, 8, 4, 4, 8],
[0, 9, 4, 7, 8],
[8, 8, 0, 1, 4]])
values, indices = torch.max(a,1) # get max values, indices
temp= torch.zeros_like(values) # temporary
temp[values==9]=1 # fill temp where values are 9 (wished value)
seq=torch.arange(values.shape[0]) # create a helper sequence
new_seq=seq[temp>0] # filter sequence where values are 9
new_temp=indices[new_seq] # filter indices with sequence where values are 9
final = torch.stack([new_seq, new_temp], dim=1) # stack both to get result
print(final)
tensor([[0, 1],
[2, 0],
[3, 1]])