在 Tensorflow 运行时更改 tf.dataset 源代码

Change tf.dataset source at runtime in Tensorflow

我有两个 tf.datasets,一个用于验证,一个用于训练。

我时不时地想切换数据源,以便我可以 运行 验证并检查它的一些准确性度量。

This blog 建议使用占位符并将普通的 numpy 数组提供给它。但这会破坏整个效率目的; 正如 tf.data API api 指南所说:

Warning: "Feeding" is the least efficient way to feed data into a TensorFlow program and should only be used for small experiments and debugging.

所以,这是我想要实现的概念示例:

# Load the datasets from tfrecord files:
val_dataset = tf.data.TFRecordDataset([val_recordfile])
train_dataset = tf.data.TFRecordDataset([train_recordfile])

##  Batch size end shuffeling etc. here  ##

iterator_tr = train_dataset.make_initializable_iterator()
iterator_val = val_dataset.make_initializable_iterator()

###############################################
##  This is the magic:                       ##
it_op=tf.iterator_placeholder()
##  tf.iterator_placeholder does not exist!  ##
##  and demonstrates my needs                ##
###############################################
X, Y = it_op.get_next()
predictions=model(X)
train_op=train_and_minimize(X,Y)
acc_op=get_accuracy(Y,predictions)

with tf.Session() as sess:
    # Initialize iterator here
    accuracy_tr,_=sess.run([acc_op,train_op], feed_dict={it_op: iterator_tr})
    accuracy_val=sess.run(acc_op, feed_dict={it_op: iterator_val})

不是当然必须以这种方式完成!

我更喜欢 pytonic/ideomatic 张量流方式,但任何不需要提供原始数据的方式对我来说都很棒!

事实证明,我的建议并没有那么奏效。它实际上出现在 the Tensorflow's guide on datasets:

# Load the datasets in some form of tf.Dataset
tr_dataset=get_dataset(TRAINING)
val_dataset=get_dataset(VALIDATION)
# batching etc..
train_iterator = tr_dataset.make_initializable_iterator()
val_iterator = val_dataset.make_initializable_iterator()
# Make iterator handle that takes a string identifier
iterator_handle = tf.placeholder(tf.string, shape=[])
iterator=tf.data.Iterator.from_string_handle(iterator_handle, train_iterator.output_types,output_shapes=train_iterator.output_shapes)
with tf.Session() as sess:
    # Create string handlers for the iterators
    train_iterator_handle = sess.run(train_iterator.string_handle())
    val_iterator_handle = sess.run(val_iterator.string_handle())
    # Now initialize iterators
    sess.run(train_iterator.initializer, feed_dict={iterator_handle: train_iterator_handle})
    sess.run(val_iterator.initializer, feed_dict={iterator_handle: val_iterator_handle})