Tensorflow,如何连接具有不同批量大小的多个数据集
Tensorflow, how to concatenate multiple datasets with varying batch sizes
假设我有:
- 数据集 1,数据为 [5, 5, 5, 5, 5]
- 包含数据 [4, 4] 的数据集 2
我想从两个数据集中获取批次并将它们连接起来,以便我得到大小为 3 的批次,其中:
- 我读取了批量大小为 2 的数据集 1
- 我读取了批量大小为 1 的数据集 2。
如果某些数据集先被清空,我还想阅读最后一批。
在这种情况下,我会得到 [5, 5, 4], [5, 5, 4], [5] 作为我的最终结果。
我该怎么做?
我在这里看到了答案:Tensorflow how to generate unbalanced combined data sets
这是一个很好的尝试,但如果其中一个数据集先于其他数据集被清空,它就不起作用(因为当您尝试从获取的数据集中获取元素时,tf.errors.OutOfRangeError
会先发制人地输出先清空,我没有得到最后一批)。因此我只得到 [5, 5, 4], [5, 5, 4]
我想到了使用 tf.contrib.data.choose_from_datasets
:
ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)
choice_dataset = [1, 2, 1, 2, 1]
ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)
这种工作但是比较不优雅(有unbatch和batch);另外,我真的无法很好地控制批次中的确切内容。 (例如,如果 ds1 是 [7] * 7,批量大小为 2,而 ds2 是 [2, 2],批量大小为 1,我会得到 [7, 7, 1], [7, 7, 1], [7 , 7, 7]。但是如果我真的想要 [7, 7, 1], [7, 7, 1], [7, 7], [7] 怎么办?即保持每个数据集中的元素数量固定.
还有其他更好的解决方案吗?
我的另一个想法是尝试使用 tf.data.Dataset.flat_map
:
ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])
batch_sizes = [2, 1]
def concat(*inputs):
concat = partial(functools.reduce, lambda x, y: x.concatenate(y))
datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
return concat(datasets)
dataset = (tf.data.Dataset
.zip((ds1, ds2))
.flat_map(_concat_and_batch)
.batch(sum(batch_sizes)))
但是好像不行..
这是一个解决方案。虽然有一些问题,但希望能满足您的需求。
想法如下:您对两个数据集分别进行批处理,将它们压缩在一起,然后执行映射函数将每个压缩的元组合并为一个批次(到目前为止,这与 and this 个答案。)
如您所见,问题在于压缩仅适用于长度相同的两个数据集。否则,一个数据集先于另一个数据集被消费,其余未消费的元素不被使用。
我(有点古怪)的解决方案是将两个数据集连接到另一个无限虚拟数据集。该虚拟数据集仅包含您知道不会出现在真实数据集中的值。这消除了压缩问题。但是,您需要摆脱所有虚拟元素。这可以通过过滤和映射轻松完成。
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
# we assume that this value will never occur in `ds1` and `ds2`:
UNUSED_VALUE = -1
# an infinite dummy dataset:
dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat()
# make `ds1` and `ds2` infinite:
ds1 = ds1.concatenate(dummy_ds)
ds2 = ds2.concatenate(dummy_ds)
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
# this is the solution mentioned in the links above
ds = tf.data.Dataset.zip((ds1,ds2))
ds = ds.map(lambda x1, x2: tf.concat((x1,x2),0))
# filter the infinite dummy tail:
ds = ds.filter(lambda x: tf.reduce_any(tf.not_equal(x,UNUSED_VALUE)))
# filter from batches the dummy elements:
ds = ds.map(lambda x: tf.boolean_mask(x,tf.not_equal(x,UNUSED_VALUE)))
这个解决方案有两个主要问题:
(1) 我们需要 UNUSED_VALUE
的值,我们确定它不会出现在数据集中。我怀疑有一个解决方法,可能是通过使虚拟数据集包含空张量(而不是具有常量值的张量),但我还不知道该怎么做。
(2)虽然这个数据集的元素个数是有限的,但是下面的循环永远不会终止:
iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess = tf.Session()
while True:
print(sess.run(batch))
原因是迭代器一直在过滤掉虚拟示例,不知道什么时候停止。这可以通过将上面的 repeat()
调用更改为 repeat(n)
来解决,其中 n
是一个您知道比两个数据集的长度差长的数字。
如果您不介意 运行 在构建新数据集期间进行会话,您可以执行以下操作:
import tensorflow as tf
import numpy as np
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
batch1 = iter1.get_next()
batch2 = iter2.get_next()
sess = tf.Session()
# define a generator that will sess.run both datasets, and will return the concatenation of both
def GetBatch():
while True:
try:
b1 = sess.run(batch1)
except tf.errors.OutOfRangeError:
b1 = None
try:
b2 = sess.run(batch2)
except tf.errors.OutOfRangeError:
b2 = None
if (b1 is None) and (b2 is None):
break
elif b1 is None:
yield b2
elif b2 is None:
yield b1
else:
yield np.concatenate((b1,b2))
# create a dataset from the above generator
ds = tf.data.Dataset.from_generator(GetBatch,tf.int32)
注意上面的session可以是hidden\encapsulated如果你愿意(比如在一个函数里面),例如:
iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess2 = tf.Session()
while True:
print(sess2.run(batch))
这里有一个解决方案,需要你使用一个"control input",来选择使用哪个批次,你根据首先使用哪个数据集来决定。这可以使用抛出的异常来检测。
为了解释这个解决方案,我先给出一个无效的尝试。
尝试的解决方案 #1
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
batch1 = iter1.get_next(name='batch1')
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)
# this is a "control" placeholder. Its value determines whether to use `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)
batch = tf.cond(
tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
lambda:batch12,
lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
lambda:batch1,
lambda:batch2)) # else, use `batch2`
sess = tf.Session()
which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
try:
print(sess.run(batch,feed_dict={which_batch:which}))
except tf.errors.OutOfRangeError as e:
# use the error to detect which dataset was consumed, and update `which` accordingly
if which==0:
if 'batch2' in e.op.name:
which = 1
else:
which = 2
else:
break
此解决方案不起作用,因为对于 which_batch
的任何值,tf.cond()
命令将评估其分支的所有前导(请参阅 )。因此,即使 which_batch
的值为 1,也会计算 batch2
并抛出 OutOfRangeError
。
尝试的解决方案 #2
可以通过将 batch1
、batch2
和 batch12
的定义移动到函数中来解决此问题。
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
def get_batch1():
batch1 = iter1.get_next(name='batch1')
return batch1
def get_batch2():
batch2 = iter2.get_next(name='batch2')
return batch2
def get_batch12():
batch1 = iter1.get_next(name='batch1_')
batch2 = iter2.get_next(name='batch2_')
batch12 = tf.concat((batch1, batch2), 0)
return batch12
# this is a "control" placeholder. It's value determines whether to ues `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)
batch = tf.cond(
tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
get_batch12,
lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
get_batch1,
get_batch2)) # elif `which_batch`==2, use `batch2`
sess = tf.Session()
which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
try:
print(sess.run(batch,feed_dict={which_batch:which}))
except tf.errors.OutOfRangeError as e:
# use the error to detect which dataset was consumed, and update `which` accordingly
if which==0:
if 'batch2' in e.op.name:
which = 1
else:
which = 2
else:
break
但是,这也不起作用。原因是,在 batch12
形成并消耗数据集 ds2
的步骤中,我们从数据集 ds1
和 "dropped" 中获取了批次,但没有使用它。
解决方案
我们需要一种机制来确保我们不会 "drop" 在其他数据集被消耗的情况下进行任何批处理。我们可以通过定义一个变量来做到这一点,该变量将分配给当前批次的 ds1
,但 仅在尝试获取 batch12
之前立即分配。否则,此变量将保留其先前的值。那么,如果batch12
由于ds1
被消耗而失败,那么这次赋值就会失败,batch2
没有被丢弃,我们可以在下次使用它。否则,如果 batch12
由于 ds2
被消耗而失败,那么我们在我们定义的变量中有一个 batch1
的备份,使用这个备份后我们可以继续使用 batch1
.
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
# this variable will store a backup of `batch1`, in case it is dropped
batch1_backup = tf.Variable(0, trainable=False, validate_shape=False)
def get_batch12():
batch1 = iter1.get_next(name='batch1')
# form the combined batch `batch12` only after backing-up `batch1`
with tf.control_dependencies([tf.assign(batch1_backup, batch1, validate_shape=False)]):
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)
return batch12
def get_batch1():
batch1 = iter1.get_next()
return batch1
def get_batch2():
batch2 = iter2.get_next()
return batch2
# this is a "control" placeholder. Its value determines whether to use `batch12`, `batch1_backup`, `batch1`, or `batch2`
which_batch = tf.Variable(0,trainable=False)
batch = tf.cond(
tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
get_batch12,
lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1_backup`
lambda:batch1_backup,
lambda:tf.cond(tf.equal(which_batch,2), # elif `which_batch`==2, use `batch1`
get_batch1,
get_batch2))) # else, use `batch2`
sess = tf.Session()
sess.run(tf.global_variables_initializer())
which = 0 # this value will be fed into the control placeholder
while True:
try:
print(sess.run(batch,feed_dict={which_batch:which}))
# if just used `batch1_backup`, proceed with `batch1`
if which==1:
which = 2
except tf.errors.OutOfRangeError as e:
# use the error to detect which dataset was consumed, and update `which` accordingly
if which == 0:
if 'batch2' in e.op.name:
which = 1
else:
which = 3
else:
break
假设我有:
- 数据集 1,数据为 [5, 5, 5, 5, 5]
- 包含数据 [4, 4] 的数据集 2
我想从两个数据集中获取批次并将它们连接起来,以便我得到大小为 3 的批次,其中:
- 我读取了批量大小为 2 的数据集 1
- 我读取了批量大小为 1 的数据集 2。
如果某些数据集先被清空,我还想阅读最后一批。 在这种情况下,我会得到 [5, 5, 4], [5, 5, 4], [5] 作为我的最终结果。
我该怎么做? 我在这里看到了答案:Tensorflow how to generate unbalanced combined data sets
这是一个很好的尝试,但如果其中一个数据集先于其他数据集被清空,它就不起作用(因为当您尝试从获取的数据集中获取元素时,tf.errors.OutOfRangeError
会先发制人地输出先清空,我没有得到最后一批)。因此我只得到 [5, 5, 4], [5, 5, 4]
我想到了使用 tf.contrib.data.choose_from_datasets
:
ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)
choice_dataset = [1, 2, 1, 2, 1]
ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)
这种工作但是比较不优雅(有unbatch和batch);另外,我真的无法很好地控制批次中的确切内容。 (例如,如果 ds1 是 [7] * 7,批量大小为 2,而 ds2 是 [2, 2],批量大小为 1,我会得到 [7, 7, 1], [7, 7, 1], [7 , 7, 7]。但是如果我真的想要 [7, 7, 1], [7, 7, 1], [7, 7], [7] 怎么办?即保持每个数据集中的元素数量固定.
还有其他更好的解决方案吗?
我的另一个想法是尝试使用 tf.data.Dataset.flat_map
:
ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])
batch_sizes = [2, 1]
def concat(*inputs):
concat = partial(functools.reduce, lambda x, y: x.concatenate(y))
datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
return concat(datasets)
dataset = (tf.data.Dataset
.zip((ds1, ds2))
.flat_map(_concat_and_batch)
.batch(sum(batch_sizes)))
但是好像不行..
这是一个解决方案。虽然有一些问题,但希望能满足您的需求。
想法如下:您对两个数据集分别进行批处理,将它们压缩在一起,然后执行映射函数将每个压缩的元组合并为一个批次(到目前为止,这与
如您所见,问题在于压缩仅适用于长度相同的两个数据集。否则,一个数据集先于另一个数据集被消费,其余未消费的元素不被使用。
我(有点古怪)的解决方案是将两个数据集连接到另一个无限虚拟数据集。该虚拟数据集仅包含您知道不会出现在真实数据集中的值。这消除了压缩问题。但是,您需要摆脱所有虚拟元素。这可以通过过滤和映射轻松完成。
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
# we assume that this value will never occur in `ds1` and `ds2`:
UNUSED_VALUE = -1
# an infinite dummy dataset:
dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat()
# make `ds1` and `ds2` infinite:
ds1 = ds1.concatenate(dummy_ds)
ds2 = ds2.concatenate(dummy_ds)
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
# this is the solution mentioned in the links above
ds = tf.data.Dataset.zip((ds1,ds2))
ds = ds.map(lambda x1, x2: tf.concat((x1,x2),0))
# filter the infinite dummy tail:
ds = ds.filter(lambda x: tf.reduce_any(tf.not_equal(x,UNUSED_VALUE)))
# filter from batches the dummy elements:
ds = ds.map(lambda x: tf.boolean_mask(x,tf.not_equal(x,UNUSED_VALUE)))
这个解决方案有两个主要问题:
(1) 我们需要 UNUSED_VALUE
的值,我们确定它不会出现在数据集中。我怀疑有一个解决方法,可能是通过使虚拟数据集包含空张量(而不是具有常量值的张量),但我还不知道该怎么做。
(2)虽然这个数据集的元素个数是有限的,但是下面的循环永远不会终止:
iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess = tf.Session()
while True:
print(sess.run(batch))
原因是迭代器一直在过滤掉虚拟示例,不知道什么时候停止。这可以通过将上面的 repeat()
调用更改为 repeat(n)
来解决,其中 n
是一个您知道比两个数据集的长度差长的数字。
如果您不介意 运行 在构建新数据集期间进行会话,您可以执行以下操作:
import tensorflow as tf
import numpy as np
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
batch1 = iter1.get_next()
batch2 = iter2.get_next()
sess = tf.Session()
# define a generator that will sess.run both datasets, and will return the concatenation of both
def GetBatch():
while True:
try:
b1 = sess.run(batch1)
except tf.errors.OutOfRangeError:
b1 = None
try:
b2 = sess.run(batch2)
except tf.errors.OutOfRangeError:
b2 = None
if (b1 is None) and (b2 is None):
break
elif b1 is None:
yield b2
elif b2 is None:
yield b1
else:
yield np.concatenate((b1,b2))
# create a dataset from the above generator
ds = tf.data.Dataset.from_generator(GetBatch,tf.int32)
注意上面的session可以是hidden\encapsulated如果你愿意(比如在一个函数里面),例如:
iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess2 = tf.Session()
while True:
print(sess2.run(batch))
这里有一个解决方案,需要你使用一个"control input",来选择使用哪个批次,你根据首先使用哪个数据集来决定。这可以使用抛出的异常来检测。
为了解释这个解决方案,我先给出一个无效的尝试。
尝试的解决方案 #1
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
batch1 = iter1.get_next(name='batch1')
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)
# this is a "control" placeholder. Its value determines whether to use `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)
batch = tf.cond(
tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
lambda:batch12,
lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
lambda:batch1,
lambda:batch2)) # else, use `batch2`
sess = tf.Session()
which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
try:
print(sess.run(batch,feed_dict={which_batch:which}))
except tf.errors.OutOfRangeError as e:
# use the error to detect which dataset was consumed, and update `which` accordingly
if which==0:
if 'batch2' in e.op.name:
which = 1
else:
which = 2
else:
break
此解决方案不起作用,因为对于 which_batch
的任何值,tf.cond()
命令将评估其分支的所有前导(请参阅 which_batch
的值为 1,也会计算 batch2
并抛出 OutOfRangeError
。
尝试的解决方案 #2
可以通过将 batch1
、batch2
和 batch12
的定义移动到函数中来解决此问题。
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
def get_batch1():
batch1 = iter1.get_next(name='batch1')
return batch1
def get_batch2():
batch2 = iter2.get_next(name='batch2')
return batch2
def get_batch12():
batch1 = iter1.get_next(name='batch1_')
batch2 = iter2.get_next(name='batch2_')
batch12 = tf.concat((batch1, batch2), 0)
return batch12
# this is a "control" placeholder. It's value determines whether to ues `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)
batch = tf.cond(
tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
get_batch12,
lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
get_batch1,
get_batch2)) # elif `which_batch`==2, use `batch2`
sess = tf.Session()
which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
try:
print(sess.run(batch,feed_dict={which_batch:which}))
except tf.errors.OutOfRangeError as e:
# use the error to detect which dataset was consumed, and update `which` accordingly
if which==0:
if 'batch2' in e.op.name:
which = 1
else:
which = 2
else:
break
但是,这也不起作用。原因是,在 batch12
形成并消耗数据集 ds2
的步骤中,我们从数据集 ds1
和 "dropped" 中获取了批次,但没有使用它。
解决方案
我们需要一种机制来确保我们不会 "drop" 在其他数据集被消耗的情况下进行任何批处理。我们可以通过定义一个变量来做到这一点,该变量将分配给当前批次的 ds1
,但 仅在尝试获取 batch12
之前立即分配。否则,此变量将保留其先前的值。那么,如果batch12
由于ds1
被消耗而失败,那么这次赋值就会失败,batch2
没有被丢弃,我们可以在下次使用它。否则,如果 batch12
由于 ds2
被消耗而失败,那么我们在我们定义的变量中有一个 batch1
的备份,使用这个备份后我们可以继续使用 batch1
.
import tensorflow as tf
ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])
ds1 = ds1.batch(2)
ds2 = ds2.batch(1)
iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()
# this variable will store a backup of `batch1`, in case it is dropped
batch1_backup = tf.Variable(0, trainable=False, validate_shape=False)
def get_batch12():
batch1 = iter1.get_next(name='batch1')
# form the combined batch `batch12` only after backing-up `batch1`
with tf.control_dependencies([tf.assign(batch1_backup, batch1, validate_shape=False)]):
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)
return batch12
def get_batch1():
batch1 = iter1.get_next()
return batch1
def get_batch2():
batch2 = iter2.get_next()
return batch2
# this is a "control" placeholder. Its value determines whether to use `batch12`, `batch1_backup`, `batch1`, or `batch2`
which_batch = tf.Variable(0,trainable=False)
batch = tf.cond(
tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
get_batch12,
lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1_backup`
lambda:batch1_backup,
lambda:tf.cond(tf.equal(which_batch,2), # elif `which_batch`==2, use `batch1`
get_batch1,
get_batch2))) # else, use `batch2`
sess = tf.Session()
sess.run(tf.global_variables_initializer())
which = 0 # this value will be fed into the control placeholder
while True:
try:
print(sess.run(batch,feed_dict={which_batch:which}))
# if just used `batch1_backup`, proceed with `batch1`
if which==1:
which = 2
except tf.errors.OutOfRangeError as e:
# use the error to detect which dataset was consumed, and update `which` accordingly
if which == 0:
if 'batch2' in e.op.name:
which = 1
else:
which = 3
else:
break