为什么 dim=1 return 在 torch.argmax 中行索引?
Why does dim=1 return row indices in torch.argmax?
我正在研究 PyTorch 的 argmax
函数,定义为:
torch.argmax(input, dim=None, keepdim=False)
考虑一个例子
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
此处当我使用 dim=1 而不是搜索列向量时,该函数搜索行向量,如下所示。
print(a) :
tensor([[-1.7739, 0.8073, 0.0472, -0.4084],
[ 0.6378, 0.6575, -1.2970, -0.0625],
[ 1.7970, -1.3463, 0.9011, -0.8704],
[ 1.5639, 0.7123, 0.0385, 1.8410]])
print(torch.argmax(a, dim=1))
tensor([1, 1, 0, 3])
就我的假设而言,dim = 0 表示行,dim =1 表示列。
是时候正确理解 axis
或 dim
参数在 PyTorch 中的工作原理了:
理解了上图后,下面的例子应该就明白了:
|
v
dim-0 ---> -----> dim-1 ------> -----> --------> dim-1
| [[-1.7739, 0.8073, 0.0472, -0.4084],
v [ 0.6378, 0.6575, -1.2970, -0.0625],
| [ 1.7970, -1.3463, 0.9011, -0.8704],
v [ 1.5639, 0.7123, 0.0385, 1.8410]]
|
v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
注意:dim
('dimension'的缩写)相当于[=25=的手电筒]'axis' 在 NumPy 中。
我正在研究 PyTorch 的 argmax
函数,定义为:
torch.argmax(input, dim=None, keepdim=False)
考虑一个例子
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
此处当我使用 dim=1 而不是搜索列向量时,该函数搜索行向量,如下所示。
print(a) :
tensor([[-1.7739, 0.8073, 0.0472, -0.4084],
[ 0.6378, 0.6575, -1.2970, -0.0625],
[ 1.7970, -1.3463, 0.9011, -0.8704],
[ 1.5639, 0.7123, 0.0385, 1.8410]])
print(torch.argmax(a, dim=1))
tensor([1, 1, 0, 3])
就我的假设而言,dim = 0 表示行,dim =1 表示列。
是时候正确理解 axis
或 dim
参数在 PyTorch 中的工作原理了:
理解了上图后,下面的例子应该就明白了:
| v dim-0 ---> -----> dim-1 ------> -----> --------> dim-1 | [[-1.7739, 0.8073, 0.0472, -0.4084], v [ 0.6378, 0.6575, -1.2970, -0.0625], | [ 1.7970, -1.3463, 0.9011, -0.8704], v [ 1.5639, 0.7123, 0.0385, 1.8410]] | v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
注意:dim
('dimension'的缩写)相当于[=25=的手电筒]'axis' 在 NumPy 中。