如何在 TensorFlow 中取消引用 _ref 张量类型?
How to dereference _ref tensor type in TensorFlow?
如何将引用张量类型转换为值张量类型?
我找到的唯一方法是给张量加一个零。有什么方便的方法吗?
下面assign
是引用类型的张量。如何摆脱 _ref
?
import tensorflow as tf
counter = tf.Variable(0, name="counter")
zero = tf.constant(0)
one = tf.constant(1)
new_counter = tf.add(counter, one)
assign = tf.assign(counter, new_counter) # dtype=int32_ref
result = tf.add(assign, zero) # dtype=int32
result2 = tf.convert_to_tensor(assign) # dtype=int32_ref
# result3 = assign.value() # has no attribute value
一般来说,您应该能够在任何需要 tf.foo
类型张量的地方使用 tf.foo_ref
类型张量。 TensorFlow ops 将隐式取消引用它们的输入参数(除非明确期望引用张量,例如 tf.assign()
)。
取消引用张量的最简单方法是使用tf.identity()
,如下所示:
counter = tf.Variable(0)
assert counter.dtype == tf.int32_ref
counter_val = tf.identity(counter)
assert counter_val.dtype == tf.int32
请注意,这回答了您的问题,但可能具有令人惊讶的语义,因为 tf.identity()
不会复制 底层缓冲区。因此,上例中的counter
和counter_val
共享同一个缓冲区,对counter
的修改会反映在counter_val
:
中
counter = tf.Variable(0)
counter_val = tf.identity(counter) # Take alias before the `assign_add` happens.
counter_update = counter.assign_add(1)
with tf.control_dependencies([counter_update]):
# Force a copy after the `assign_add` happens.
result = counter_val + 0
sess = tf.Session()
sess.run(tf.initialize_all_variables())
print sess.run(result) # ==> 1 (result has effect of `assign_add`)
如何将引用张量类型转换为值张量类型?
我找到的唯一方法是给张量加一个零。有什么方便的方法吗?
下面assign
是引用类型的张量。如何摆脱 _ref
?
import tensorflow as tf
counter = tf.Variable(0, name="counter")
zero = tf.constant(0)
one = tf.constant(1)
new_counter = tf.add(counter, one)
assign = tf.assign(counter, new_counter) # dtype=int32_ref
result = tf.add(assign, zero) # dtype=int32
result2 = tf.convert_to_tensor(assign) # dtype=int32_ref
# result3 = assign.value() # has no attribute value
一般来说,您应该能够在任何需要 tf.foo
类型张量的地方使用 tf.foo_ref
类型张量。 TensorFlow ops 将隐式取消引用它们的输入参数(除非明确期望引用张量,例如 tf.assign()
)。
取消引用张量的最简单方法是使用tf.identity()
,如下所示:
counter = tf.Variable(0)
assert counter.dtype == tf.int32_ref
counter_val = tf.identity(counter)
assert counter_val.dtype == tf.int32
请注意,这回答了您的问题,但可能具有令人惊讶的语义,因为 tf.identity()
不会复制 底层缓冲区。因此,上例中的counter
和counter_val
共享同一个缓冲区,对counter
的修改会反映在counter_val
:
counter = tf.Variable(0)
counter_val = tf.identity(counter) # Take alias before the `assign_add` happens.
counter_update = counter.assign_add(1)
with tf.control_dependencies([counter_update]):
# Force a copy after the `assign_add` happens.
result = counter_val + 0
sess = tf.Session()
sess.run(tf.initialize_all_variables())
print sess.run(result) # ==> 1 (result has effect of `assign_add`)