PyTorch torch.max 多个维度

PyTorch torch.max over multiple dimensions

有像这样的张量:x.shape = [3, 2, 2]

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

我需要 .max() 在第 2 和第 3 个维度上。我希望像这样 [-0.2632, -0.1453, -0.0274] 作为输出。我尝试使用:x.max(dim=(1,2)),但这会导致错误。

现在,您可以做到这一点。 PR was merged(8 月 28 日),现在在每晚版本中可用。

只需使用torch.amax():

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(torch.amax(x, dim=(1, 2)))

# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])

原答案

截至今天(2020 年 4 月 11 日),无法在 PyTorch 中对多个维度执行 .min().max()。有一个关于它的 open issue,您可以关注它并查看它是否得到实施。您的解决方法是:

import torch

x = torch.tensor([
    [[-0.3000, -0.2926],[-0.2705, -0.2632]],
    [[-0.1821, -0.1747],[-0.1526, -0.1453]],
    [[-0.0642, -0.0568],[-0.0347, -0.0274]]
])

print(x.view(x.size(0), -1).max(dim=-1))

# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))

因此,如果您只需要以下值:x.view(x.size(0), -1).max(dim=-1).values

如果x不是一个连续的张量,那么.view()就会失败。在这种情况下,您应该改用 .reshape()


更新 2020 年 8 月 26 日

此功能正在 PR#43092 中实现,函数将被称为 aminamax。他们只会 return 值。这可能很快就会被合并,所以当你阅读这篇文章时,你可能能够在夜间构建中访问这些功能:)祝你玩得开心。

虽然 解决了这个具体问题,但我认为添加一些解释可能会帮助每个人阐明这里使用的技巧,以便它可以适用于 (m) 任何其他维度。

让我们从检查输入张量的形状开始 x:

In [58]: x.shape   
Out[58]: torch.Size([3, 2, 2])

所以,我们有一个形状为 (3, 2, 2) 的 3D 张量。现在,根据 OP 的问题,我们需要计算张量中沿 1st 和 2nd 维度的值的 maximum .在撰写本文时,torch.max()dim 参数仅支持 int。所以,我们不能使用元组。因此,我们将使用以下技巧,我称之为

Flatten & Max 技巧:因为我们想在 1st 和 2[=54 上计算 max =]nd 维度,我们会将这两个维度展平为一个维度,并保持第 0th 维度不变。这正是正在发生的事情:

In [61]: x.flatten().reshape(x.shape[0], -1).shape   
Out[61]: torch.Size([3, 4])   # 2*2 = 4

所以,现在我们已经将 3D 张量缩小为 2D 张量(即矩阵)。

In [62]: x.flatten().reshape(x.shape[0], -1) 
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
        [-0.1821, -0.1747, -0.1526, -0.1453],
        [-0.0642, -0.0568, -0.0347, -0.0274]])

现在,我们可以简单地将 max 应用于第 1st 个维度(即在这种情况下,第一个维度也是最后一个维度),因为展平的维度居住在那个维度。

In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1)    # or: `dim = -1`
Out[65]: 
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))

由于矩阵中有 3 行,因此我们在结果张量中得到了 3 个值。


现在,另一方面,如果你想在第 0 和第 1 个维度上计算 max,你会做:

In [80]: x.flatten().reshape(-1, x.shape[-1]).shape 
Out[80]: torch.Size([6, 2])    # 3*2 = 6

In [79]: x.flatten().reshape(-1, x.shape[-1]) 
Out[79]: 
tensor([[-0.3000, -0.2926],
        [-0.2705, -0.2632],
        [-0.1821, -0.1747],
        [-0.1526, -0.1453],
        [-0.0642, -0.0568],
        [-0.0347, -0.0274]])

现在,我们可以简单地将 max 应用于第 0 维度,因为这是我们展平的结果。 ((另外,从我们原来的形状(3, 2, 2),在前两个维度上取最大值后,我们应该得到两个值作为结果。)

In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0) 
Out[82]: 
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))

以类似的方式,您可以将此方法应用于多维和其他缩减函数,例如 min


注意:我遵循基于 0 的维度 (0, 1, 2, 3, ...) 的术语只是为了与 PyTorch 用法和代码保持一致。

如果你只想使用torch.max()函数获取二维张量中最大条目的索引,你可以这样做:

max_i_vals, max_i_indices = torch.max(x, 0)
print('max_i_vals, max_i_indices: ', max_i_vals, max_i_indices)
max_j_index = torch.max(max_i_vals, 0)[1]
print('max_j_index: ', max_j_index)
max_index = [max_i_indices[max_j_index], max_j_index]
print('max_index: ', max_index)

在测试中,上面为我打印出来:

max_i_vals: tensor([0.7930, 0.7144, 0.6985, 0.7349, 0.9162, 0.5584, 1.4777, 0.8047, 0.9008, 1.0169, 0.6705, 0.9034, 1.1159, 0.8852, 1.0353], grad_fn=\<MaxBackward0>)   
max_i_indices: tensor([ 5,  8, 10,  6, 13, 14,  5,  6,  6,  6, 13,  4, 13, 13, 11])  
max_j_index:  tensor(6)  
max_index:  [tensor(5), tensor(6)]

这种方法可以扩展到 3 个维度。虽然不像这个 post 中的其他答案那样视觉上令人愉悦,但这个答案表明仅使用 torch.max() 函数就可以解决问题(尽管我同意 built-in 支持 torch.max() 在多个维度上将是一个福音)。

跟进
我偶然发现了 similar question in the PyTorch forums 并且 poster ptrblck 提供了这行代码作为获取张量 x:

中最大条目的索引的解决方案
x = (x==torch.max(x)).nonzero()

这个 one-liner 不仅不需要调整代码就可以使用 N-dimensional 张量,而且它也比我上面写的方法快得多(至少 2:1 ratio)并且比接受的答案(大约 3:2 ratio)更快,根据我的基准。

如果您使用 torch <= 1.10,请使用此自定义 max 函数

def torch_max(x,dim):
    s1 = [i for i in range(len(x.shape)) if i not in dim] 
    s2 = [i for i in range(len(x.shape)) if i in dim]
    x2 = x.permute(tuple(s1+s2))
    s = [d for (i,d) in enumerate(x.shape) if i not in dim] + [-1]
    x2 = torch.reshape(x2, tuple(s))
    max,_ = x2.max(-1)
    return max 

然后运行喜欢

print(torch_max(x, dim=(1, 2)))