How to efficiently retrieve the indices of maximum values in a Torch tensor?
Assume I have a torch tensor of, for example, the following shape:

x = torch.rand(20, 1, 120, 120)

What I would like now is to get the indices of the maximum value of each 120x120 matrix. To simplify this problem, let me first x.squeeze() it to work with shape [20, 120, 120]. I would then like to get a torch tensor which is a list of indices with shape [20, 2].

How can I do this fast?
If I get you correctly, you don't want the values but the indices. Unfortunately there is no out-of-the-box solution. There is an argmax() function, but I cannot see how to get it to do exactly what you want.
So here is a small workaround; the efficiency should also be okay, since we're just dividing tensors:
n = torch.tensor(4)
d = torch.tensor(4)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
# argmax() only returns the index into the flattened matrix block,
# so we have to compute the row and column indices ourselves
# using integer division (//) and modulo (%)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
print(x)
print(indices)
Here n represents your first dimension and d the last two dimensions. I use smaller numbers here to show the result, but of course this also works for n=20 and d=120:
n = torch.tensor(20)
d = torch.tensor(120)
x = torch.rand(n, 1, d, d)
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
#print(x)
print(indices)
Here is the output for n=4 and d=4:
tensor([[[[0.3699, 0.3584, 0.4940, 0.8618],
          [0.6767, 0.7439, 0.5984, 0.5499],
          [0.8465, 0.7276, 0.3078, 0.3882],
          [0.1001, 0.0705, 0.2007, 0.4051]]],

        [[[0.7520, 0.4528, 0.0525, 0.9253],
          [0.6946, 0.0318, 0.5650, 0.7385],
          [0.0671, 0.6493, 0.3243, 0.2383],
          [0.6119, 0.7762, 0.9687, 0.0896]]],

        [[[0.3504, 0.7431, 0.8336, 0.0336],
          [0.8208, 0.9051, 0.1681, 0.8722],
          [0.5751, 0.7903, 0.0046, 0.1471],
          [0.4875, 0.1592, 0.2783, 0.6338]]],

        [[[0.9398, 0.7589, 0.6645, 0.8017],
          [0.9469, 0.2822, 0.9042, 0.2516],
          [0.2576, 0.3852, 0.7349, 0.2806],
          [0.7062, 0.1214, 0.0922, 0.1385]]]])
tensor([[0, 3],
        [3, 2],
        [1, 1],
        [1, 0]])
I hope this is what you wanted to get! :)

Edit:

Here is a slightly modified version which might be minimally faster (not much, I guess :), but it is a bit simpler and prettier:

Instead of this, as before:
m = x.view(n, -1).argmax(1)
indices = torch.cat(((m // d).view(-1, 1), (m % d).view(-1, 1)), dim=1)
the necessary reshaping is already done on the argmax values:
m = x.view(n, -1).argmax(1).view(-1, 1)
indices = torch.cat((m // d, m % d), dim=1)
But as mentioned in the comments, I don't think it is possible to get much more out of it.

One thing you could do, if it is really important for you to squeeze out the last bit of performance, is to implement the above function as a low-level extension (like in C++) for PyTorch.

This would give you just one function you can call and would avoid slow Python code.
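As a side note, on recent PyTorch versions (1.8+) the floor division can also be spelled out explicitly, since torch.div accepts a rounding_mode argument; this is a purely stylistic sketch equivalent to the version above:

m = x.view(n, -1).argmax(1).view(-1, 1)
# torch.div with rounding_mode='floor' is the documented spelling of integer division
indices = torch.cat((torch.div(m, d, rounding_mode='floor'), m % d), dim=1)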
torch.topk() is what you are looking for. From the docs:

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)

Returns the k largest elements of the given input tensor along a given dimension.

If dim is not given, the last dimension of the input is chosen.

If largest is False then the k smallest elements are returned.

A namedtuple of (values, indices) is returned, where the indices are the indices of the elements in the original input tensor.

The boolean option sorted, if True, will make sure that the returned k elements are themselves sorted.
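Here is a minimal sketch of applying topk with k=1 to the question's setup (the variable names are illustrative):

import torch

x = torch.rand(20, 1, 120, 120).squeeze(1)          # shape [20, 120, 120]
values, flat_idx = x.view(20, -1).topk(k=1, dim=1)  # flat index of each max, shape [20, 1]
# convert each flat index back to (row, col) coordinates
indices = torch.cat((flat_idx // 120, flat_idx % 120), dim=1)  # shape [20, 2]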
# ps is assumed to be a 1xN tensor of class scores
ps = ps.numpy()
ps = ps.tolist()
mx = [max(l) for l in ps]
mx = max(mx)
for i in range(len(ps[0])):
    if mx == ps[0][i]:
        print("The digit is " + str(i))
        break

This worked perfectly for me.
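For what it's worth, the same result can be computed without converting to numpy (a sketch, assuming ps is a 1xN tensor of scores): since the tensor has a single row, the flat argmax index equals the column index.

print("The digit is " + str(ps.argmax().item()))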
Here is an unravel_index implementation in torch:
from typing import Tuple

import torch


def unravel_index(
    indices: torch.LongTensor,
    shape: Tuple[int, ...],
) -> torch.LongTensor:
    r"""Converts flat indices into unraveled coordinates in a target shape.

    This is a `torch` implementation of `numpy.unravel_index`.

    Args:
        indices: A tensor of (flat) indices, (*, N).
        shape: The targeted shape, (D,).

    Returns:
        The unraveled coordinates, (*, N, D).
    """
    coord = []
    for dim in reversed(shape):
        coord.append(indices % dim)
        indices = indices // dim
    coord = torch.stack(coord[::-1], dim=-1)
    return coord
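A quick standalone check with hypothetical values: for shape (3, 4), the flat indices 0, 5, and 11 should map to (0, 0), (1, 1), and (2, 3).

print(unravel_index(torch.tensor([0, 5, 11]), (3, 4)))
# tensor([[0, 0],
#         [1, 1],
#         [2, 3]])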
Then, you can use the torch.argmax function to get the indices of the "flattened" tensor:
y = x.view(20, -1)
indices = torch.argmax(y, dim=1)
indices.shape  # (20,)
And unravel the indices with the unravel_index function:
indices = unravel_index(indices, x.shape[-2:])
indices.shape # (20, 2)
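As a quick sanity check (a hypothetical verification, assuming x from the question), the values at the recovered coordinates should equal the per-matrix maxima:

x = torch.rand(20, 1, 120, 120).squeeze(1)  # shape [20, 120, 120]
coords = unravel_index(x.view(20, -1).argmax(dim=1), x.shape[-2:])
rows, cols = coords[:, 0], coords[:, 1]
assert torch.equal(x[torch.arange(20), rows, cols],
                   x.view(20, -1).max(dim=1).values)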
The accepted answer only works for the given example.

tejasvi88's answer is interesting but does not help answer the original question (as explained in my comment there).

I believe Francois's answer is the closest, because it deals with the more general case (any number of dimensions). However, it does not connect to argmax, and the example shown does not illustrate that function's ability to handle batches.

So I will build upon Francois's answer here and add code to connect it to argmax. I wrote a new function batch_argmax that returns the indices of the maximum values within a batch. The batch may be organized in multiple dimensions. I also include some test cases for illustration:
from math import prod

import torch


def batch_argmax(tensor, batch_dim=1):
    """
    Assumes that dimensions of tensor up to batch_dim are "batch dimensions"
    and returns the indices of the max element of each "batch row".
    More precisely, returns tensor `a` such that, for each index v of tensor.shape[:batch_dim],
    a[v] is the indices of the max element of tensor[v].
    """
    if batch_dim >= len(tensor.shape):
        raise NoArgMaxIndices()
    batch_shape = tensor.shape[:batch_dim]
    non_batch_shape = tensor.shape[batch_dim:]
    flat_non_batch_size = prod(non_batch_shape)
    tensor_with_flat_non_batch_portion = tensor.reshape(*batch_shape, flat_non_batch_size)

    dimension_of_indices = len(non_batch_shape)

    # We now have each batch row flattened in the last dimension of tensor_with_flat_non_batch_portion,
    # so we can invoke its argmax(dim=-1) method. However, that method throws an exception if the tensor
    # is empty. We cover that case first.
    if tensor_with_flat_non_batch_portion.numel() == 0:
        # If empty, either the batch dimensions or the non-batch dimensions are empty
        batch_size = prod(batch_shape)
        if batch_size == 0:  # if batch dimensions are empty
            # return empty tensor of appropriate shape
            batch_of_unraveled_indices = torch.ones(*batch_shape, dimension_of_indices).long()  # 'ones' is irrelevant as it will be empty
        else:  # non-batch dimensions are empty, so argmax indices are undefined
            raise NoArgMaxIndices()
    else:  # We actually have elements to maximize, so we search for them
        indices_of_non_batch_portion = tensor_with_flat_non_batch_portion.argmax(dim=-1)
        # `unravel_indices` is Francois's batched unraveling helper; the `unravel_index`
        # function shown earlier in this thread computes the same thing
        batch_of_unraveled_indices = unravel_indices(indices_of_non_batch_portion, non_batch_shape)

    if dimension_of_indices == 1:
        # the above function makes each unraveled index of an n-D tensor an n-long tensor;
        # however, indices of 1D tensors are typically represented by scalars, so we squeeze them in this case
        batch_of_unraveled_indices = batch_of_unraveled_indices.squeeze(dim=-1)
    return batch_of_unraveled_indices


class NoArgMaxIndices(BaseException):

    def __init__(self):
        super(NoArgMaxIndices, self).__init__(
            "no argmax indices: batch_argmax requires non-batch shape to be non-empty")
And here are the tests:
def test_basic():
    # a simple array
    tensor = torch.tensor([0, 1, 2, 3, 4])
    batch_dim = 0
    expected = torch.tensor(4)
    run_test(tensor, batch_dim, expected)

    # making batch_dim = 1 renders the non-batch portion empty and argmax indices undefined
    tensor = torch.tensor([0, 1, 2, 3, 4])
    batch_dim = 1
    check_that_exception_is_thrown(lambda: batch_argmax(tensor, batch_dim), NoArgMaxIndices)

    # now a batch of arrays
    tensor = torch.tensor([[1, 2, 3], [6, 5, 4]])
    batch_dim = 1
    expected = torch.tensor([2, 0])
    run_test(tensor, batch_dim, expected)

    # Now we have an empty batch with non-batch 3-dim arrays' shape (the arrays are actually non-existent)
    tensor = torch.ones(0, 3)  # 'ones' is irrelevant since this is empty
    batch_dim = 1
    # empty batch of the right shape: just the batch dimension 0, since indices of arrays are scalar (0D)
    expected = torch.ones(0)
    run_test(tensor, batch_dim, expected)

    # Now we have an empty batch with non-batch matrices' shape (the matrices are actually non-existent)
    tensor = torch.ones(0, 3, 2)  # 'ones' is irrelevant since this is empty
    batch_dim = 1
    # empty batch of the right shape: the batch dimension and two dimensions for the indices since we have 2D matrices
    expected = torch.ones(0, 2)
    run_test(tensor, batch_dim, expected)

    # a batch of 2D matrices:
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = 1
    expected = torch.tensor([[1, 0], [1, 2]])  # coordinates of two 6's, one in each 2D matrix
    run_test(tensor, batch_dim, expected)

    # same as before, but testing that batch_dim supports negative values
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = -2
    expected = torch.tensor([[1, 0], [1, 2]])
    run_test(tensor, batch_dim, expected)

    # Same data, but a 2-dimensional batch of 1D arrays!
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = 2
    expected = torch.tensor([[2, 0], [1, 2]])  # coordinates of 3, 6, 3, and 6
    run_test(tensor, batch_dim, expected)

    # same as before, but testing that batch_dim supports negative values
    tensor = torch.tensor([[[1, 2, 3], [6, 5, 4]], [[2, 3, 1], [4, 5, 6]]])
    batch_dim = -1
    expected = torch.tensor([[2, 0], [1, 2]])
    run_test(tensor, batch_dim, expected)


def run_test(tensor, batch_dim, expected):
    actual = batch_argmax(tensor, batch_dim)
    print(f"batch_argmax of {tensor} with batch_dim {batch_dim} is\n{actual}\nExpected:\n{expected}")
    assert actual.shape == expected.shape
    assert actual.eq(expected).all()


def check_that_exception_is_thrown(thunk, exception_type):
    if isinstance(exception_type, BaseException):
        raise Exception(f"check_that_exception_is_thrown received an exception instance rather than "
                        f"an exception type: {exception_type}")
    try:
        thunk()
        raise AssertionError(f"Should have thrown {exception_type}")
    except exception_type:
        pass
    except Exception as e:
        raise AssertionError(f"Should have thrown {exception_type} but instead threw {e}")