Tensorflow: while loop on tensor
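I'm trying to apply a while loop to the values of a tensor. For example, for the variable "a" I'm trying to gradually increase the values of the tensor until a certain condition is met. However, I keep getting this error: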
ValueError: Shape must be rank 0 but is rank 3 for 'while_12/LoopCond' (op: 'LoopCond') with input shapes: [3,1,1].
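import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API
# Example value of `a`:
# array([[[0.76393723]],
#        [[0.93270312]],
#        [[0.08361106]]])
a = np.random.random((3, 1, 1))
a1 = tf.constant(np.float64(a))
i = tf.constant(np.float64(6.14))
c = lambda i: tf.less(i, a1)
b = lambda x: tf.add(x, 0.1)
# Fails with the ValueError above: the condition returns a [3,1,1] boolean tensor
r = tf.while_loop(c, b, [a1])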
The condition passed to tf.while_loop() must return a scalar boolean (a tensor of rank 0 is a scalar; that's what the error message is about). In your example you probably want the condition to return True as long as all the numbers in the a1 tensor are less than 6.14. This can be achieved by reducing the elementwise comparison with tf.reduce_all() (logical AND) or tf.reduce_any() (logical OR).
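To see why the original condition is rejected, here is a small NumPy analogue (an assumption for illustration: NumPy stands in for TensorFlow, with `np.all`/`np.any` playing the role of `tf.reduce_all`/`tf.reduce_any`). An elementwise comparison keeps the input's rank, while the reductions collapse it to a rank-0 scalar, which is the shape a loop condition needs.

```python
import numpy as np

# Same shape as the question's `a`: rank 3, shape (3, 1, 1)
a = np.array([[[0.76393723]],
              [[0.93270312]],
              [[0.08361106]]])

elementwise = np.less(a, 6.14)       # one boolean per element, still rank 3
all_less = np.all(np.less(a, 6.14))  # logical AND over every element -> scalar
any_less = np.any(np.less(a, 6.14))  # logical OR over every element -> scalar

print(elementwise.shape, elementwise.ndim)  # (3, 1, 1) 3  (not usable as a condition)
print(all_less.shape, all_less.ndim)        # () 0
print(any_less.shape, any_less.ndim)        # () 0
```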
This snippet works for me:
import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API
tf.reset_default_graph()
# Random integers in [1, 3] (np.random.random_integers is deprecated)
a = np.random.randint(1, 4, size=(3, 2))
print(a)
# e.g.
# [[1 1]
#  [2 3]
#  [1 1]]
a1 = tf.constant(a)
i = 6
# condition returns a scalar True while any number in `x` is less than 6
condition = lambda x: tf.reduce_any(tf.less(x, i))
body = lambda x: tf.add(x, 1)
loop = tf.while_loop(
    condition,
    body,
    [a1],
)
with tf.Session() as sess:
    result = sess.run(loop)
    print(result)
# [[6 6]
#  [7 8]
#  [6 6]]
# All numbers are now greater than or equal to 6
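The difference between the two reductions is easy to miss, so here is a plain NumPy emulation of the loop applied to the question's values (step 0.1, threshold 6.14). `emulated_while_loop` is a hypothetical helper, not TensorFlow's implementation; it only mimics the rule that the condition must be rank 0. With an `all`-style condition the loop stops as soon as any element reaches 6.14; with an `any`-style condition it keeps going until every element does.

```python
import numpy as np

def emulated_while_loop(cond, body, x, max_iters=10_000):
    """Hypothetical NumPy stand-in for tf.while_loop with one loop variable.

    Like tf.while_loop, it requires `cond` to return a rank-0 boolean.
    """
    for _ in range(max_iters):
        keep_going = np.asarray(cond(x))
        if keep_going.ndim != 0:
            raise ValueError(
                f"Shape must be rank 0 but is rank {keep_going.ndim}")
        if not keep_going:
            return x
        x = body(x)
    raise RuntimeError("loop did not terminate")

a = np.array([[[0.76393723]],
              [[0.93270312]],
              [[0.08361106]]])
step = lambda x: x + 0.1  # mirrors tf.add(x, 0.1) from the question

# The question's elementwise condition is rejected, mirroring the TF error
try:
    emulated_while_loop(lambda x: np.less(x, 6.14), step, a)
except ValueError as e:
    print(e)  # Shape must be rank 0 but is rank 3

# reduce_all-style condition: stops as soon as ANY element reaches 6.14
r_all = emulated_while_loop(lambda x: np.all(x < 6.14), step, a)
# reduce_any-style condition: continues until EVERY element reaches 6.14
r_any = emulated_while_loop(lambda x: np.any(x < 6.14), step, a)
```

Pick the reduction that matches the stopping rule you want; the answer's snippet uses the `any` form, so every element ends up at or above the threshold.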