Tensorflow:跨批次的 2x2 矩阵中的最大值索引

Tensorflow: Indices of max value in 2x2 matrix across batch

如果我有一批矩阵,形状为 (?, 600, 600),我将如何检索批次中每个矩阵中最大值的行和列索引?这样我的行和列 return 矩阵都是形状(?)(行 return 矩阵具有批处理中每个示例的最大值行的索引,并且与 col return矩阵)。

谢谢!

你可以重塑 + argmax。类似于:

x = tf.reshape(matrix, [tf.shape(matrix, 0), -1])
indices = tf.argmax(x, axis=1)  # this gives you indices from 0 to 600^2
col_indices = indices / 600
row_indices = indices % 600
final_indices = tf.transpose(tf.stack(col_indices, row_indices))