Change tf.dataset source at runtime in Tensorflow
I have two tf.datasets, one for validation and one for training.
Every now and then I want to switch the data source so that I can run the validation and check some accuracy measure on it.
This blog suggests using a placeholder and feeding ordinary numpy arrays to it. But that would defeat the entire purpose of efficiency; as the tf.data API guide says:
Warning: "Feeding" is the least efficient way to feed data into a TensorFlow program and should only be used for small experiments and debugging.
So, here is a conceptual example of what I want to achieve:
# Load the datasets from tfrecord files:
val_dataset = tf.data.TFRecordDataset([val_recordfile])
train_dataset = tf.data.TFRecordDataset([train_recordfile])
## Batch size and shuffling etc. here ##
iterator_tr = train_dataset.make_initializable_iterator()
iterator_val = val_dataset.make_initializable_iterator()
###############################################
## This is the magic: ##
it_op = tf.iterator_placeholder()
## tf.iterator_placeholder does not exist! ##
## and demonstrates my needs ##
###############################################
X, Y = it_op.get_next()
predictions = model(X)
train_op = train_and_minimize(X, Y)
acc_op = get_accuracy(Y, predictions)
with tf.Session() as sess:
    # Initialize the chosen iterator here
    accuracy_tr, _ = sess.run([acc_op, train_op], feed_dict={it_op: iterator_tr})
    accuracy_val = sess.run(acc_op, feed_dict={it_op: iterator_val})
Of course it doesn't have to be done exactly this way!
I'd prefer a pythonic/idiomatic TensorFlow way, but any approach that doesn't require feeding raw data would be great for me!
It turns out my suggestion wasn't that far off: essentially the same pattern appears in the TensorFlow guide on datasets:
# Load the datasets in some form of tf.Dataset
tr_dataset = get_dataset(TRAINING)
val_dataset = get_dataset(VALIDATION)
# batching etc.
train_iterator = tr_dataset.make_initializable_iterator()
val_iterator = val_dataset.make_initializable_iterator()
# Make iterator handle that takes a string identifier
iterator_handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(iterator_handle, train_iterator.output_types, output_shapes=train_iterator.output_shapes)
with tf.Session() as sess:
    # Create string handles for the iterators
    train_iterator_handle = sess.run(train_iterator.string_handle())
    val_iterator_handle = sess.run(val_iterator.string_handle())
    # Now initialize the iterators; the handle placeholder is not needed here,
    # it is only fed when running ops that pull data through the generic iterator
    sess.run(train_iterator.initializer)
    sess.run(val_iterator.initializer)
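With the generic iterator in place, switching sources is just a matter of which handle you feed when running your ops. A minimal sketch, reusing the hypothetical model/train_and_minimize/get_accuracy helpers from the question above:

# Built once on the generic iterator, before the session:
X, Y = iterator.get_next()
predictions = model(X)
train_op = train_and_minimize(X, Y)
acc_op = get_accuracy(Y, predictions)

# Inside the session above, after the initializers have run:
accuracy_tr, _ = sess.run([acc_op, train_op],
                          feed_dict={iterator_handle: train_iterator_handle})
accuracy_val = sess.run(acc_op,
                        feed_dict={iterator_handle: val_iterator_handle})

The only thing fed at run time is a short string handle, not the data itself, so the input pipeline keeps running natively inside the TensorFlow runtime.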