张量流:检查标量布尔张量是否为真

tensorflow: check if a scalar boolean tensor is True

我想使用占位符控制函数的执行,但一直收到错误 "Using a tf.Tensor as a Python bool is not allowed"。这是产生此错误的代码:

import tensorflow as tf
def foo(c):
  if c:
    print('This is true')
    #heavy code here
    return 10
  else:
    print('This is false')
    #different code here
    return 0

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()

我把 if c 改成 if c is not None 很不走运。那么如何通过打开和关闭占位符 a 来控制 foo

更新:正如@nessuno 和@nemo 指出的,我们必须使用tf.cond 而不是if..else。我的问题的答案是像这样重新设计我的功能:

import tensorflow as tf
def foo(c):
  return tf.cond(c, func1, func2)

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = foo(a)
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close() 

您必须使用 tf.cond 在图中定义条件运算并更改,从而改变张量的流动。

import tensorflow as tf

a = tf.placeholder(tf.bool)  #placeholder for a single boolean value
b = tf.cond(tf.equal(a, tf.constant(True)), lambda: tf.constant(10), lambda: tf.constant(0))
sess = tf.InteractiveSession()
res = sess.run(b, feed_dict = {a: True})
sess.close()
print(res)

10

实际的执行不是在 Python 中完成的,而是在 TensorFlow 后端中完成的,您提供了它应该执行的计算图。这意味着您要应用的每个条件和流量控制都必须被表述为计算图中的一个节点。

对于if条件有cond操作:

b = tf.cond(c, 
           lambda: tf.constant(10), 
           lambda: tf.constant(0))

更简单的解决方法:

In [50]: a = tf.placeholder(tf.bool)                                                                                                                                                                                 

In [51]: is_true = tf.count_nonzero([a])                                                                                                                                                                             

In [52]: sess.run(is_true, {a: True})                                                                                                                                                                                
Out[52]: 1

In [53]: sess.run(is_true, {a: False})
Out[53]: 0