手动获取下一批或使用与 TensorFlow 数据相同的批次 API

Manually Get Next Batch or Use Identical Batch with TensorFlow Data API

我正在尝试使用 tf.Data API 来加速我的代码并防止 GPU 数据匮乏,但有一件事让我无法适应它,那就是使用它的能力多次调用训练操作时的同一批次。

假设我的数据集设置为

dataset = tf.data.TextLineDataset("textfile.txt")
dataset = dataset.shuffle(dataset_size)
dataset = dataset.padded_batch(batch_size, ...)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
x_batch = iterator.get_next()

loss1 = someFunctionOf(x_batch)
loss2 = someOtherFunctionOf(x_batch)
train_op1 = someOptimizerOf(loss1)
train_op2 = someOtherOptimizerOf(loss2)

但现在每当我调用 train_op1 时,都会调用 iterator.get_next(),因此当调用 train_op2 时,我正在对下一批进行训练。

问题,我知道我可以使用 flat_maprepeat(n) 的组合,其中 n 是我想重复相同的次数批处理,但这 n 将取决于我调用的 train_ops 的数量,我必须手动计算。另外,我需要这两个 train_ops 因为它们优化了我图表的不同部分。

感谢您的帮助!

试试下面的代码。它会创建输入和目标的副本,因此希望它们在您切换 optimizer/loss_op 时不会更改。只要您不传递 is_new:True 标志,它们就会在 sess.run 调用之间持续存在。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf


def ds_train(batch_size, num_epochs):  
    ds = (tf.data.Dataset.from_tensor_slices(([1.0,2.0,3.0,4.0,5.0], [-1,-2,-3,-4,-5]))
            .batch(batch_size)
            .repeat(num_epochs)        
            )
    return ds


batch_size = 1
input_size = 1
num_epochs = 2

with tf.variable_scope("dataset"):       
    ds_t = ds_train(batch_size, num_epochs)

with tf.variable_scope("iterator"):
    iterator_t = ds_t.make_initializable_iterator()
    iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
    iterator = tf.data.Iterator.from_string_handle(iterator_handle, 
                                                iterator_t.output_types,
                                                iterator_t.output_shapes)

    def next_item():
        next_elem = iterator.get_next(name="next_element")
        x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]# tf.cast(next_elem[1], tf.int32)
        return x, y        


inputs = tf.Variable(tf.zeros(shape=[batch_size,input_size]), dtype=tf.float32, name="inputs", trainable=False, use_resource=True)
target = tf.Variable(tf.zeros(shape=[batch_size], dtype=tf.int32), dtype=tf.int32, name="target", trainable=False,use_resource=True)
is_new = tf.placeholder_with_default(tf.constant(False), shape=[], name="new_item_flag")

def new_data(batch_size, input_size):
    # run the data layer to generate a new batch
    next_inputs, next_target = next_item()
    next_inputs = tf.reshape(next_inputs, shape=[batch_size, input_size])
    with tf.control_dependencies([tf.assign(inputs, next_inputs), tf.assign(target, next_target)]):
        return tf.identity(inputs), tf.identity(target)

def old_data():
    # just forward the existing batch
    return inputs, target

next_inputs, next_target = next_item()

inputs, target =  tf.cond(is_new, lambda:new_data(batch_size, input_size), old_data)

with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
    handle_t = sess.run(iterator_t.string_handle())
    sess.run(iterator_t.initializer)
    while True:
        try:
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: False}))
            print(sess.run([inputs, target], feed_dict={iterator_handle:handle_t, is_new: True}))
        except tf.errors.OutOfRangeError:
            print("End of training dataset.")
            break