pytorch Argmax 冲突情况下的索引选择
index selection in case of conflict in pytorch Argmax
我一直在尝试学习张量运算,但这个让我陷入了困境。
假设我有一个张量 t:
t = torch.tensor([
[1,0,0,2],
[0,3,3,0],
[4,0,0,5]
], dtype = torch.float32)
现在这是一个 2 阶张量,我们可以为每个 rank/dimension 应用 argmax。
假设我们将它应用于 dim = 1
t.max(dim = 1)
(tensor([2., 3., 5.]), tensor([3, 2, 3]))
现在我们可以看到结果与预期一致,沿 dim =1 的张量具有 2,3 和 5 作为最大元素。但是在3上有冲突,有两个值完全相似
它是如何解决的?它是任意选择的吗?像L-R一样,索引值高的选择有顺序吗?
我很感激任何关于如何解决这个问题的见解!
这是一个我自己偶然发现的好问题。最简单的答案是,无法保证 torch.argmax
(或 torch.max(x, dim=k)
,当指定 dim 时,它们也是 return 的索引)将 return 始终如一地使用相同的索引。相反,它将 return argmax 值的任何有效索引 ,可能是随机的。正如 this thread in the official forum 所讨论的,这被认为是期望的行为。 (我知道我刚才读到的另一个线程使这一点更加明确,但我找不到了)。
话虽如此,由于这种行为对我的用例来说是不可接受的,我编写了以下函数来查找最左边和最右边的索引(请注意 condition
是您传入的函数对象):
def __consistent_args(input, condition, indices):
assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
return torch.argmax(mask, dim=1)
def consistent_find_leftmost(input, condition):
indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
return __consistent_args(input, condition, indices)
def consistent_find_rightmost(input, condition):
indices = torch.arange(0, input.size(1), 1, dtype=torch.float, device=input.device)
return __consistent_args(input, condition, indices)
# one example:
consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)
# will return:
# tensor([6])
希望他们能帮到你! (哦,如果您有更好的实现方式,请告诉我)
我一直在尝试学习张量运算,但这个让我陷入了困境。
假设我有一个张量 t:
t = torch.tensor([
[1,0,0,2],
[0,3,3,0],
[4,0,0,5]
], dtype = torch.float32)
现在这是一个 2 阶张量,我们可以为每个 rank/dimension 应用 argmax。 假设我们将它应用于 dim = 1
t.max(dim = 1)
(tensor([2., 3., 5.]), tensor([3, 2, 3]))
现在我们可以看到结果与预期一致,沿 dim =1 的张量具有 2,3 和 5 作为最大元素。但是在3上有冲突,有两个值完全相似
它是如何解决的?它是任意选择的吗?像L-R一样,索引值高的选择有顺序吗?
我很感激任何关于如何解决这个问题的见解!
这是一个我自己偶然发现的好问题。最简单的答案是,无法保证 torch.argmax
(或 torch.max(x, dim=k)
,当指定 dim 时,它们也是 return 的索引)将 return 始终如一地使用相同的索引。相反,它将 return argmax 值的任何有效索引 ,可能是随机的。正如 this thread in the official forum 所讨论的,这被认为是期望的行为。 (我知道我刚才读到的另一个线程使这一点更加明确,但我找不到了)。
话虽如此,由于这种行为对我的用例来说是不可接受的,我编写了以下函数来查找最左边和最右边的索引(请注意 condition
是您传入的函数对象):
def __consistent_args(input, condition, indices):
assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
return torch.argmax(mask, dim=1)
def consistent_find_leftmost(input, condition):
indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
return __consistent_args(input, condition, indices)
def consistent_find_rightmost(input, condition):
indices = torch.arange(0, input.size(1), 1, dtype=torch.float, device=input.device)
return __consistent_args(input, condition, indices)
# one example:
consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)
# will return:
# tensor([6])
希望他们能帮到你! (哦,如果您有更好的实现方式,请告诉我)