使用 argmax 获得索引的散点更新张量
scatter update tensor with index obtained using argmax
我正在尝试用另一个值更新张量的最大值,如下所示:
actions = tf.argmax(output, axis=1)
gen_targets = tf.scatter_nd_update(output, actions, q_value)
我收到一个错误:AttributeError: 'Tensor' object has no attribute 'handle'
在 scatter_nd_update
。
output
和 actions
是声明为的占位符:
output = tf.placeholder('float', shape=[None, num_action])
reward = tf.placeholder('float', shape=[None])
我做错了什么,实现这个的正确方法是什么?
您正在尝试更新类型为 tf.placeholder
的 output
的值。占位符是不可变对象,您无法更新占位符的值。您尝试更新的张量应该是变量类型,例如tf.Variable, in order for tf.scatter_nd_update() 能够更新它的值。
解决此问题的一种方法是创建一个变量,然后使用 tf.assign() 将占位符的值分配给该变量。由于占位符的维度之一是 None
并且在运行期间可能是任意大小,您可能希望将 tf.assign()
的 validate_shape
参数设置为 False
,这样占位符的形状不需要与变量的形状相匹配。分配后,var_output
的形状将与通过占位符输入的对象的实际形状匹配。
output = tf.placeholder('float', shape=[None, num_action])
# dummy variable initialization
var_output = tf.Variable(0, dtype=output.dtype)
# assign value of placeholder to the var_output
var_output = tf.assign(var_output, output, validate_shape=False)
# ...
gen_targets = tf.scatter_nd_update(var_output, actions, q_value)
# ...
sess.run(gen_targets, feed_dict={output: feed_your_placeholder_here})
我正在尝试用另一个值更新张量的最大值,如下所示:
actions = tf.argmax(output, axis=1)
gen_targets = tf.scatter_nd_update(output, actions, q_value)
我收到一个错误:AttributeError: 'Tensor' object has no attribute 'handle'
在 scatter_nd_update
。
output
和 actions
是声明为的占位符:
output = tf.placeholder('float', shape=[None, num_action])
reward = tf.placeholder('float', shape=[None])
我做错了什么,实现这个的正确方法是什么?
您正在尝试更新类型为 tf.placeholder
的 output
的值。占位符是不可变对象,您无法更新占位符的值。您尝试更新的张量应该是变量类型,例如tf.Variable, in order for tf.scatter_nd_update() 能够更新它的值。
解决此问题的一种方法是创建一个变量,然后使用 tf.assign() 将占位符的值分配给该变量。由于占位符的维度之一是 None
并且在运行期间可能是任意大小,您可能希望将 tf.assign()
的 validate_shape
参数设置为 False
,这样占位符的形状不需要与变量的形状相匹配。分配后,var_output
的形状将与通过占位符输入的对象的实际形状匹配。
output = tf.placeholder('float', shape=[None, num_action])
# dummy variable initialization
var_output = tf.Variable(0, dtype=output.dtype)
# assign value of placeholder to the var_output
var_output = tf.assign(var_output, output, validate_shape=False)
# ...
gen_targets = tf.scatter_nd_update(var_output, actions, q_value)
# ...
sess.run(gen_targets, feed_dict={output: feed_your_placeholder_here})