Pytorch:解释 torch.argmax

Pytorch: explain torch.argmax

您好,我有以下代码:

import torch
x = torch.zeros(1,8,4,576) # create a 4 dimensional tensor
x[0,4,2,333] = 1.0 # put on 1 on a random spot

# I want to find the index of the highest value (0,4,2,333)
print(x.argmax()) # this should return the index

这个returns

tensor(10701)

这个10701有什么意义?

如何获取实际索引 0,4,2,333?

4维数组中的数据在内存中线性存储,argmax()returns这个平面表示的对应索引

Numpy 有解开索引的函数(从平面数组索引转换为对应的多维索引)。

import numpy as np
np.unravel_index(10701, (1,8,4,576))