如何找到 tf.Tensor 中的最大值?

How do I find the max value in tf.Tensor?

如何找到每个元素中的最大值,以便得到 2、4、6、8?

import tensorflow as tf
a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])

我尝试了以下代码:

tf.reduce_max(a, keepdims=True)

但这只给了我 8 作为输出而忽略了其余部分。

您必须像这样将 axis 参数更改为 -1:

import tensorflow as tf
a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])

print(tf.reduce_max(a, axis=-1, keepdims=False))

'''
tf.Tensor(
[[2]
 [4]
 [6]
 [8]], shape=(4, 1), dtype=int32)
'''

因为你有一个 3D 张量并且想要访问最后一个维度。