在损失函数中索引 tf 变量

Indexing a tf variable in loss function

我在Tensorflow 1.9.0中自定义了一个损失函数(项目限制无法升级)。我有以下变量,在特征值分解后获得:

# eigw.shape = (?, x)
# eigv.shape = (?, x, y)

现在,我想计算 eigwargmax,这样

amax = tf.argmax(eigw, axis=1, output_type=tf.int32)
# amax.shape = (?,)

我想用 amax 中给出的值索引 eigv,这样

# result.shape = (?, y)

我该如何实现?我尝试直接访问它,但这样做我 运行 遇到了形状不具有相同级别的问题。另外,我尝试使用 tf.while_loop,但我是 tf 的新手,因此我没有成功。

我还有哪些其他选择?我怎样才能最轻松地解决这个问题?

谢谢

在您的特定情况下,您可以使用任何沿轴而不是索引收集最大值的 TensorFlow 函数。

max_value = tf.math.reduce_max(eigw, axis=1)

您可能会在文档中看到任何其他参数。由于 tesnorlfow.org 上没有更多的 TF 1.9 文档,我可以找到 r1.15,它仍然使用静态图。 https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/math/reduce_max

解决方案比预期的要简单。在深入研究我发现的文档后,特征向量已经“按非递减顺序排序”。 所以我只需要取最后一个特征向量。感谢大家的投稿。