使用 argmax 获得索引的散点更新张量

scatter update tensor with index obtained using argmax

我正在尝试用另一个值更新张量的最大值,如下所示:

actions = tf.argmax(output, axis=1)
gen_targets = tf.scatter_nd_update(output, actions, q_value)

我收到一个错误:AttributeError: 'Tensor' object has no attribute 'handle'scatter_nd_update

outputactions 是声明为的占位符:

output = tf.placeholder('float', shape=[None, num_action])
reward = tf.placeholder('float', shape=[None])

我做错了什么,实现这个的正确方法是什么?

您正在尝试更新类型为 tf.placeholderoutput 的值。占位符是不可变对象,您无法更新占位符的值。您尝试更新的张量应该是变量类型,例如tf.Variable, in order for tf.scatter_nd_update() 能够更新它的值。 解决此问题的一种方法是创建一个变量,然后使用 tf.assign() 将占位符的值分配给该变量。由于占位符的维度之一是 None 并且在运行期间可能是任意大小,您可能希望将 tf.assign()validate_shape 参数设置为 False,这样占位符的形状不需要与变量的形状相匹配。分配后,var_output 的形状将与通过占位符输入的对象的实际形状匹配。

output = tf.placeholder('float', shape=[None, num_action])
# dummy variable initialization
var_output = tf.Variable(0, dtype=output.dtype)

# assign value of placeholder to the var_output
var_output = tf.assign(var_output, output, validate_shape=False)
# ...
gen_targets = tf.scatter_nd_update(var_output, actions, q_value)
# ...
sess.run(gen_targets, feed_dict={output: feed_your_placeholder_here})