TensorFlow 中的索引操作
Index operation in TensorFlow
我在做一些数据的批量标注时,有一个变量用来记录所有的计算结果:
p_all = tf.Variable(tf.zeros([batch_num, batch_size]), name = "probability");
在计算中,我有一个循环来处理每个批次:
for i in range(batch_num):
feed = {x: testDS.test.next_batch(batch_size)}
sess.run(p_each_batch, feed_dict=feed)
如何将 p_each_bach
的值复制到 p_all
中?
为了更清楚,我想要这样的东西:
... ...
p_all[batch_index,:] = p_each_batch
for i in range(batch_num):
feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
sess.run(p_all, feed_dict=feed)
我怎样才能使这些代码真正起作用?
因为 p_all
是一个 tf.Variable
, you can use the tf.scatter_update()
操作来更新每个批次中的单个行:
# Equivalent to `p_all[batch_index, :] = p_each_batch`
update_op = tf.scatter_update(p_all,
tf.expand_dims(batch_index, 0),
tf.expand_dims(p_each_batch, 0))
for i in range(batch_num):
feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
sess.run(update_op, feed_dict=feed)
我在做一些数据的批量标注时,有一个变量用来记录所有的计算结果:
p_all = tf.Variable(tf.zeros([batch_num, batch_size]), name = "probability");
在计算中,我有一个循环来处理每个批次:
for i in range(batch_num):
feed = {x: testDS.test.next_batch(batch_size)}
sess.run(p_each_batch, feed_dict=feed)
如何将 p_each_bach
的值复制到 p_all
中?
为了更清楚,我想要这样的东西:
... ...
p_all[batch_index,:] = p_each_batch
for i in range(batch_num):
feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
sess.run(p_all, feed_dict=feed)
我怎样才能使这些代码真正起作用?
因为 p_all
是一个 tf.Variable
, you can use the tf.scatter_update()
操作来更新每个批次中的单个行:
# Equivalent to `p_all[batch_index, :] = p_each_batch`
update_op = tf.scatter_update(p_all,
tf.expand_dims(batch_index, 0),
tf.expand_dims(p_each_batch, 0))
for i in range(batch_num):
feed = {x: testDS.test.next_batch(batch_size), batch_index: i}
sess.run(update_op, feed_dict=feed)