如何在tensorflow中使用tf.while_loop()

How to use tf.while_loop() in tensorflow

这是一个通用问题。我发现在 tensorflow 中,在我们构建图形之后,将数据提取到图形中,图形的输出是一个张量。但是在很多情况下,我们需要根据这个输出(是一个tensor)做一些计算,这在tensorflow中是不允许的。

例如,我正在尝试实现一个 RNN,它根据数据自身循环次数 属性。也就是说,我需要用一个 tensor 来判断我是否应该停止(我没有使用 dynamic_rnn 因为在我的设计中,rnn 是高度定制的)。我发现 tf.while_loop(cond,body.....) 可能适合我的实施。但是官方教程太简单了。我不知道如何向 'body' 添加更多功能。谁能给我几个更复杂的例子?

此外,在这种情况下,如果未来的计算基于张量输出(例如:RNN 停止基于输出标准),这是非常常见的情况。有没有更优雅或更好的方法来代替动态图?

是什么阻止您向 body 添加更多功能?您可以在 body 中构建您喜欢的任何复杂计算图,并从封闭图中获取您喜欢的任何输入。此外,在循环之外,您可以随心所欲地使用任何输出 return。从 'whatevers' 的数量可以看出,TensorFlow 的控制流原语在构建时考虑到了很多通用性。下面是另一个 'simple' 示例,希望对您有帮助。

import tensorflow as tf
import numpy as np

def body(x):
    a = tf.random_uniform(shape=[2, 2], dtype=tf.int32, maxval=100)
    b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
    c = a + b
    return tf.nn.relu(x + c)

def condition(x):
    return tf.reduce_sum(x) < 100

x = tf.Variable(tf.constant(0, shape=[2, 2]))

with tf.Session():
    tf.global_variables_initializer().run()
    result = tf.while_loop(condition, body, [x])
    print(result.eval())