给定 3d 张量时 argmax 如何工作 - tensorflow

How does argmax work when given a 3d tensor - tensorflow

我想知道 argmax 在给定 3D 张量时如何工作。我知道当它有 2D tesnor 时会发生什么,但 3D 让我很困惑。

示例:

import tensorflow as tf
import numpy as np

sess = tf.Session()

coordinates = np.random.randint(0, 100, size=(3, 3, 2))
coordinates
Out[20]: 
array([[[15, 23],
        [ 3,  1],
        [80, 56]],
       [[98, 95],
        [97, 82],
        [10, 37]],
       [[65, 32],
        [25, 39],
        [54, 68]]])
sess.run([tf.argmax(coordinates, axis=1)])
Out[21]: 
[array([[2, 2],
        [0, 0],
        [0, 2]], dtype=int64)]



tf.argmax returns 根据指定的轴,最大值的索引。指定的轴将被压碎,并返回每个单元的最大值的索引。返回的形状将具有相同的形状,除了将消失的指定轴。我将用 tf.reduce_max 举例,这样我们就可以遵循这些值。

让我们从您的数组开始:

x = np.array([[[15, 23],
               [3, 1],
               [80, 56]],
              [[98, 95],
               [97, 82],
               [10, 37]],
              [[65, 32],
               [25, 39],
               [54, 68]]])

tf.reduce_max(x, axis=0)

            ([[[15, 23],
                          [3, 1],
                                     [80, 56]],
              [[98, 95],               ^
                 ^   ^    [97, 82],
                            ^  ^     [10, 37]],
              [[65, 32],
                          [25, 39],
                                     [54, 68]]]) 
                                           ^    
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[98, 95],
       [97, 82],
       [80, 68]])>

现在tf.reduce_max(x, 1)

            ([[[15, 23], [[98, 95],  [[65, 32],
                            ^   ^       ^
               [3, 1],    [97, 82],   [25, 39],
           
               [80, 56]], [10, 37]],  [54, 68]]])
                 ^   ^                      ^
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[80, 56],
       [98, 95],
       [65, 68]])>

现在tf.reduce_max(x, axis=2)

            ([[[15, 23],
                     ^
               [3, 1],
                ^
               [80, 56]],
                ^   
              [[98, 95],
                 ^
               [97, 82],
                 ^
               [10, 37]],
                     ^
              [[65, 32],
                 ^
               [25, 39],
                     ^
               [54, 68]]])
                     ^
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[23,  3, 80],
       [98, 97, 37],
       [65, 39, 68]])>