在 argmax 上使用 gather 与采用 max 不同

using gather on argmax is different than taking max

我正在尝试学习在 tensorflow 上训练双 DQN 算法,但它不起作用。为了确保一切正常,我想测试一些东西。我想确保在 argmax 上使用 tf.gather 与采用最大值完全相同:假设我有一个名为 target_network:

的网络

首先让我们取最大值:

next_qvalues_target1 = target_network.get_symbolic_qvalues(next_obs_ph) #returns tensor of qvalues
next_state_values_target1 = tf.reduce_max(next_qvalues_target1, axis=1)

让我们以不同的方式尝试一下——使用 argmax 并收集:

next_qvalues_target2 = target_network.get_symbolic_qvalues(next_obs_ph) #returns same tensor of qvalues
chosen_action = tf.argmax(next_qvalues_target2, axis=1)
next_state_values_target2 = tf.gather(next_qvalues_target2, chosen_action)

diff = tf.reduce_sum(next_state_values_target1) - tf.reduce_sum(next_state_values_target2)

next_state_values_target2 和 next_state_values_target1 应该是完全相同的。所以 运行 会话应该输出 diff = 。但事实并非如此。

我错过了什么?

谢谢。

找出问题所在。所选动作的形状为 (n, 1),因此我认为对 (n, 4) 的变量使用 gather 会得到形状 (n, 1) 的结果。事实证明这不是真的。我需要将 chosen_action 变成形状 (n, 2) 的变量 - 而不是 [action1, action2, action3...] 我需要它是 [[1, action1], [2, action2] , [3, action3]...] 并使用 gather_nd 能够从 next_qvalues_target2 中获取特定元素而不是收集,因为收集需要完整的行。