PyTorch torch.max 多个维度
PyTorch torch.max over multiple dimensions
有像这样的张量:x.shape = [3, 2, 2]
。
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
我需要 .max()
在第 2 和第 3 个维度上。我希望像这样 [-0.2632, -0.1453, -0.0274]
作为输出。我尝试使用:x.max(dim=(1,2))
,但这会导致错误。
现在,您可以做到这一点。 PR was merged(8 月 28 日),现在在每晚版本中可用。
只需使用torch.amax()
:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(torch.amax(x, dim=(1, 2)))
# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])
原答案
截至今天(2020 年 4 月 11 日),无法在 PyTorch 中对多个维度执行 .min()
或 .max()
。有一个关于它的 open issue,您可以关注它并查看它是否得到实施。您的解决方法是:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(x.view(x.size(0), -1).max(dim=-1))
# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))
因此,如果您只需要以下值:x.view(x.size(0), -1).max(dim=-1).values
。
如果x
不是一个连续的张量,那么.view()
就会失败。在这种情况下,您应该改用 .reshape()
。
更新 2020 年 8 月 26 日
此功能正在 PR#43092 中实现,函数将被称为 amin
和 amax
。他们只会 return 值。这可能很快就会被合并,所以当你阅读这篇文章时,你可能能够在夜间构建中访问这些功能:)祝你玩得开心。
虽然 解决了这个具体问题,但我认为添加一些解释可能会帮助每个人阐明这里使用的技巧,以便它可以适用于 (m) 任何其他维度。
让我们从检查输入张量的形状开始 x
:
In [58]: x.shape
Out[58]: torch.Size([3, 2, 2])
所以,我们有一个形状为 (3, 2, 2)
的 3D 张量。现在,根据 OP 的问题,我们需要计算张量中沿 1st 和 2nd 维度的值的 maximum
.在撰写本文时,torch.max()
的 dim
参数仅支持 int
。所以,我们不能使用元组。因此,我们将使用以下技巧,我称之为
Flatten & Max 技巧:因为我们想在 1st 和 2[=54 上计算 max
=]nd 维度,我们会将这两个维度展平为一个维度,并保持第 0th 维度不变。这正是正在发生的事情:
In [61]: x.flatten().reshape(x.shape[0], -1).shape
Out[61]: torch.Size([3, 4]) # 2*2 = 4
所以,现在我们已经将 3D 张量缩小为 2D 张量(即矩阵)。
In [62]: x.flatten().reshape(x.shape[0], -1)
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
[-0.1821, -0.1747, -0.1526, -0.1453],
[-0.0642, -0.0568, -0.0347, -0.0274]])
现在,我们可以简单地将 max
应用于第 1st 个维度(即在这种情况下,第一个维度也是最后一个维度),因为展平的维度居住在那个维度。
In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1) # or: `dim = -1`
Out[65]:
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))
由于矩阵中有 3 行,因此我们在结果张量中得到了 3 个值。
现在,另一方面,如果你想在第 0 和第 1 个维度上计算 max
,你会做:
In [80]: x.flatten().reshape(-1, x.shape[-1]).shape
Out[80]: torch.Size([6, 2]) # 3*2 = 6
In [79]: x.flatten().reshape(-1, x.shape[-1])
Out[79]:
tensor([[-0.3000, -0.2926],
[-0.2705, -0.2632],
[-0.1821, -0.1747],
[-0.1526, -0.1453],
[-0.0642, -0.0568],
[-0.0347, -0.0274]])
现在,我们可以简单地将 max
应用于第 0 个 维度,因为这是我们展平的结果。 ((另外,从我们原来的形状(3, 2, 2
),在前两个维度上取最大值后,我们应该得到两个值作为结果。)
In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0)
Out[82]:
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))
以类似的方式,您可以将此方法应用于多维和其他缩减函数,例如 min
。
注意:我遵循基于 0 的维度 (0, 1, 2, 3, ...
) 的术语只是为了与 PyTorch 用法和代码保持一致。
如果你只想使用torch.max()
函数获取二维张量中最大条目的索引,你可以这样做:
max_i_vals, max_i_indices = torch.max(x, 0)
print('max_i_vals, max_i_indices: ', max_i_vals, max_i_indices)
max_j_index = torch.max(max_i_vals, 0)[1]
print('max_j_index: ', max_j_index)
max_index = [max_i_indices[max_j_index], max_j_index]
print('max_index: ', max_index)
在测试中,上面为我打印出来:
max_i_vals: tensor([0.7930, 0.7144, 0.6985, 0.7349, 0.9162, 0.5584, 1.4777, 0.8047, 0.9008, 1.0169, 0.6705, 0.9034, 1.1159, 0.8852, 1.0353], grad_fn=\<MaxBackward0>)
max_i_indices: tensor([ 5, 8, 10, 6, 13, 14, 5, 6, 6, 6, 13, 4, 13, 13, 11])
max_j_index: tensor(6)
max_index: [tensor(5), tensor(6)]
这种方法可以扩展到 3 个维度。虽然不像这个 post 中的其他答案那样视觉上令人愉悦,但这个答案表明仅使用 torch.max()
函数就可以解决问题(尽管我同意 built-in 支持 torch.max()
在多个维度上将是一个福音)。
跟进
我偶然发现了 similar question in the PyTorch forums 并且 poster ptrblck 提供了这行代码作为获取张量 x:
中最大条目的索引的解决方案
x = (x==torch.max(x)).nonzero()
这个 one-liner 不仅不需要调整代码就可以使用 N-dimensional 张量,而且它也比我上面写的方法快得多(至少 2:1 ratio)并且比接受的答案(大约 3:2 ratio)更快,根据我的基准。
如果您使用 torch <= 1.10,请使用此自定义 max 函数
def torch_max(x,dim):
s1 = [i for i in range(len(x.shape)) if i not in dim]
s2 = [i for i in range(len(x.shape)) if i in dim]
x2 = x.permute(tuple(s1+s2))
s = [d for (i,d) in enumerate(x.shape) if i not in dim] + [-1]
x2 = torch.reshape(x2, tuple(s))
max,_ = x2.max(-1)
return max
然后运行喜欢
print(torch_max(x, dim=(1, 2)))
有像这样的张量:x.shape = [3, 2, 2]
。
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
我需要 .max()
在第 2 和第 3 个维度上。我希望像这样 [-0.2632, -0.1453, -0.0274]
作为输出。我尝试使用:x.max(dim=(1,2))
,但这会导致错误。
现在,您可以做到这一点。 PR was merged(8 月 28 日),现在在每晚版本中可用。
只需使用torch.amax()
:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(torch.amax(x, dim=(1, 2)))
# Output:
# >>> tensor([-0.2632, -0.1453, -0.0274])
原答案
截至今天(2020 年 4 月 11 日),无法在 PyTorch 中对多个维度执行 .min()
或 .max()
。有一个关于它的 open issue,您可以关注它并查看它是否得到实施。您的解决方法是:
import torch
x = torch.tensor([
[[-0.3000, -0.2926],[-0.2705, -0.2632]],
[[-0.1821, -0.1747],[-0.1526, -0.1453]],
[[-0.0642, -0.0568],[-0.0347, -0.0274]]
])
print(x.view(x.size(0), -1).max(dim=-1))
# output:
# >>> values=tensor([-0.2632, -0.1453, -0.0274]),
# >>> indices=tensor([3, 3, 3]))
因此,如果您只需要以下值:x.view(x.size(0), -1).max(dim=-1).values
。
如果x
不是一个连续的张量,那么.view()
就会失败。在这种情况下,您应该改用 .reshape()
。
更新 2020 年 8 月 26 日
此功能正在 PR#43092 中实现,函数将被称为 amin
和 amax
。他们只会 return 值。这可能很快就会被合并,所以当你阅读这篇文章时,你可能能够在夜间构建中访问这些功能:)祝你玩得开心。
虽然
让我们从检查输入张量的形状开始 x
:
In [58]: x.shape
Out[58]: torch.Size([3, 2, 2])
所以,我们有一个形状为 (3, 2, 2)
的 3D 张量。现在,根据 OP 的问题,我们需要计算张量中沿 1st 和 2nd 维度的值的 maximum
.在撰写本文时,torch.max()
的 dim
参数仅支持 int
。所以,我们不能使用元组。因此,我们将使用以下技巧,我称之为
Flatten & Max 技巧:因为我们想在 1st 和 2[=54 上计算 max
=]nd 维度,我们会将这两个维度展平为一个维度,并保持第 0th 维度不变。这正是正在发生的事情:
In [61]: x.flatten().reshape(x.shape[0], -1).shape
Out[61]: torch.Size([3, 4]) # 2*2 = 4
所以,现在我们已经将 3D 张量缩小为 2D 张量(即矩阵)。
In [62]: x.flatten().reshape(x.shape[0], -1)
Out[62]:
tensor([[-0.3000, -0.2926, -0.2705, -0.2632],
[-0.1821, -0.1747, -0.1526, -0.1453],
[-0.0642, -0.0568, -0.0347, -0.0274]])
现在,我们可以简单地将 max
应用于第 1st 个维度(即在这种情况下,第一个维度也是最后一个维度),因为展平的维度居住在那个维度。
In [65]: x.flatten().reshape(x.shape[0], -1).max(dim=1) # or: `dim = -1`
Out[65]:
torch.return_types.max(
values=tensor([-0.2632, -0.1453, -0.0274]),
indices=tensor([3, 3, 3]))
由于矩阵中有 3 行,因此我们在结果张量中得到了 3 个值。
现在,另一方面,如果你想在第 0 和第 1 个维度上计算 max
,你会做:
In [80]: x.flatten().reshape(-1, x.shape[-1]).shape
Out[80]: torch.Size([6, 2]) # 3*2 = 6
In [79]: x.flatten().reshape(-1, x.shape[-1])
Out[79]:
tensor([[-0.3000, -0.2926],
[-0.2705, -0.2632],
[-0.1821, -0.1747],
[-0.1526, -0.1453],
[-0.0642, -0.0568],
[-0.0347, -0.0274]])
现在,我们可以简单地将 max
应用于第 0 个 维度,因为这是我们展平的结果。 ((另外,从我们原来的形状(3, 2, 2
),在前两个维度上取最大值后,我们应该得到两个值作为结果。)
In [82]: x.flatten().reshape(-1, x.shape[-1]).max(dim=0)
Out[82]:
torch.return_types.max(
values=tensor([-0.0347, -0.0274]),
indices=tensor([5, 5]))
以类似的方式,您可以将此方法应用于多维和其他缩减函数,例如 min
。
注意:我遵循基于 0 的维度 (0, 1, 2, 3, ...
) 的术语只是为了与 PyTorch 用法和代码保持一致。
如果你只想使用torch.max()
函数获取二维张量中最大条目的索引,你可以这样做:
max_i_vals, max_i_indices = torch.max(x, 0)
print('max_i_vals, max_i_indices: ', max_i_vals, max_i_indices)
max_j_index = torch.max(max_i_vals, 0)[1]
print('max_j_index: ', max_j_index)
max_index = [max_i_indices[max_j_index], max_j_index]
print('max_index: ', max_index)
在测试中,上面为我打印出来:
max_i_vals: tensor([0.7930, 0.7144, 0.6985, 0.7349, 0.9162, 0.5584, 1.4777, 0.8047, 0.9008, 1.0169, 0.6705, 0.9034, 1.1159, 0.8852, 1.0353], grad_fn=\<MaxBackward0>)
max_i_indices: tensor([ 5, 8, 10, 6, 13, 14, 5, 6, 6, 6, 13, 4, 13, 13, 11])
max_j_index: tensor(6)
max_index: [tensor(5), tensor(6)]
这种方法可以扩展到 3 个维度。虽然不像这个 post 中的其他答案那样视觉上令人愉悦,但这个答案表明仅使用 torch.max()
函数就可以解决问题(尽管我同意 built-in 支持 torch.max()
在多个维度上将是一个福音)。
跟进
我偶然发现了 similar question in the PyTorch forums 并且 poster ptrblck 提供了这行代码作为获取张量 x:
x = (x==torch.max(x)).nonzero()
这个 one-liner 不仅不需要调整代码就可以使用 N-dimensional 张量,而且它也比我上面写的方法快得多(至少 2:1 ratio)并且比接受的答案(大约 3:2 ratio)更快,根据我的基准。
如果您使用 torch <= 1.10,请使用此自定义 max 函数
def torch_max(x,dim):
s1 = [i for i in range(len(x.shape)) if i not in dim]
s2 = [i for i in range(len(x.shape)) if i in dim]
x2 = x.permute(tuple(s1+s2))
s = [d for (i,d) in enumerate(x.shape) if i not in dim] + [-1]
x2 = torch.reshape(x2, tuple(s))
max,_ = x2.max(-1)
return max
然后运行喜欢
print(torch_max(x, dim=(1, 2)))