递归分配给 Tensorflow 中的变量切片
Recursively Assign to Variable Slices in Tensorflow
我想递归地为 Tensorflow (1.15) 变量中的切片赋值。
为了说明,这个有效:
def test_loss():
m = tf.Variable(1)
n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return 1
test_loss()
Out: 1
然后我尝试了:
def test_loss():
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
for n in range(5):
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return 1
test_loss()
但是这个returns错误信息:
---> 10 A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
...
ValueError: Sliced assignment is only supported for variables
我明白 'assign' returns 不是 'Variable',因此在下一个循环中传递 'A' 将
找不到 'Variable' 了。
然后我尝试了:
def test_loss():
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
for n in range(5):
A = tf.Variable(A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5]))
return 1
test_loss()
然后我得到:
InvalidArgumentError: Input 'ref' passed float expected ref type while building NodeDef...
关于我可以递归地为 Tensorflow 变量切片赋值有什么想法吗?
以下是使用 tf.Variable
和 assign()
的一些见解。
第一个失败的解决方案
for n in range(5):
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
当你做 A.assign(B)
时,它实际上 return 是张量(即不是 tf.Variable
)。所以它适用于第一次迭代。从下一次迭代开始,您正在尝试将值分配给 tf.Tensor
,这是不允许的。
失败的第二个解决方案
for n in range(5):
A = tf.Variable(A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5]))
这又是一个非常糟糕的主意,因为您是在循环中创建变量。这样做够了,你就会 运行 内存不足。但这甚至不会 运行,因为你最终陷入了一个奇怪的僵局。您正在尝试使用一些张量创建变量,这些张量将在 Graph 执行时计算。要执行图形,您需要变量。
正确的做法
我能想到的最好的方法是 test_loss
到 return 更新操作,然后让 n
成为 TensorFlow 占位符。在 运行 会话的每次迭代中,您将一个值传递给 n
(这是当前迭代)。
def test_loss(n):
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
update = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return update
with tf.Session() as sess:
tf_n = tf.placeholder(shape=None, dtype=tf.int32, name='n')
update_op = test_loss(tf_n)
print(type(update_op))
tf.global_variables_initializer().run()
for n in range(5):
print(1)
#print(sess.run(update_op, feed_dict={tf_n: n}))
我想递归地为 Tensorflow (1.15) 变量中的切片赋值。
为了说明,这个有效:
def test_loss():
m = tf.Variable(1)
n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return 1
test_loss()
Out: 1
然后我尝试了:
def test_loss():
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
for n in range(5):
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return 1
test_loss()
但是这个returns错误信息:
---> 10 A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
...
ValueError: Sliced assignment is only supported for variables
我明白 'assign' returns 不是 'Variable',因此在下一个循环中传递 'A' 将 找不到 'Variable' 了。
然后我尝试了:
def test_loss():
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
for n in range(5):
A = tf.Variable(A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5]))
return 1
test_loss()
然后我得到:
InvalidArgumentError: Input 'ref' passed float expected ref type while building NodeDef...
关于我可以递归地为 Tensorflow 变量切片赋值有什么想法吗?
以下是使用 tf.Variable
和 assign()
的一些见解。
第一个失败的解决方案
for n in range(5):
A = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
当你做 A.assign(B)
时,它实际上 return 是张量(即不是 tf.Variable
)。所以它适用于第一次迭代。从下一次迭代开始,您正在尝试将值分配给 tf.Tensor
,这是不允许的。
失败的第二个解决方案
for n in range(5):
A = tf.Variable(A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5]))
这又是一个非常糟糕的主意,因为您是在循环中创建变量。这样做够了,你就会 运行 内存不足。但这甚至不会 运行,因为你最终陷入了一个奇怪的僵局。您正在尝试使用一些张量创建变量,这些张量将在 Graph 执行时计算。要执行图形,您需要变量。
正确的做法
我能想到的最好的方法是 test_loss
到 return 更新操作,然后让 n
成为 TensorFlow 占位符。在 运行 会话的每次迭代中,您将一个值传递给 n
(这是当前迭代)。
def test_loss(n):
m = tf.Variable(1)
#n = 3
A = tf.Variable(tf.zeros([10., 20., 30.]))
B = tf.Variable(tf.ones([10., 20., 30.]))
update = A[m+1:n+1, 10:12, 20:22].assign(B[m:n, 2:4, 3:5])
return update
with tf.Session() as sess:
tf_n = tf.placeholder(shape=None, dtype=tf.int32, name='n')
update_op = test_loss(tf_n)
print(type(update_op))
tf.global_variables_initializer().run()
for n in range(5):
print(1)
#print(sess.run(update_op, feed_dict={tf_n: n}))