如何从预测函数中捕获最高值数组的索引?

How to capture index of highest valued array from predict function?

[[0.12673968 0.15562803 0.03175346 0.6858788 ]] 这就是我的预测函数给出输出的方式,我想获取最高值的索引。 试过这个: pred= pred.tolist() print(max(pred)) index_l=pred.index(max(pred)) print(index_l)

但是好像只输出了0

打印 max(pred) 给出输出: [0.12673968076705933, 0.1556280255317688, 0.031753458082675934, 0.6858788132667542]

网络使用顺序隐藏层(嵌入、BiLSTM、BiLSTM、Dense、Dense)

你只需要使用np.argmax(pred[0])。由于您的 pred 具有形状 [[]] 而不是 [],因此您的元素是其自身的列表。因此,为了获得最大 pred,您需要使用 np.argmax(l[0]).