
Tensorflow, how to concatenate multiple datasets with varying batch sizes

Suppose I have:

I want to take batches from both datasets and concatenate them, so that I get batches of size 3 where:

I also want to read the last batch if one of the datasets gets emptied first. In that case, I would get [5, 5, 4], [5, 5, 4], [5] as my final result.
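(The dataset definitions seem to have been lost above; judging from the expected output and the code in the answers below, they appear to be ds1 = [5, 5, 5, 5, 5] with batch size 2 and ds2 = [4, 4] with batch size 1.) The desired behaviour can be sketched in plain Python, without TensorFlow:

```python
# Plain-Python sketch of the desired behaviour: on each step, take a
# batch of 2 from the first list and a batch of 1 from the second, and
# keep emitting the leftover partial batches once one list runs out.
def combined_batches(xs, ys, bs1=2, bs2=1):
    i = j = 0
    while i < len(xs) or j < len(ys):
        yield xs[i:i + bs1] + ys[j:j + bs2]
        i += bs1
        j += bs2

print(list(combined_batches([5] * 5, [4, 4])))
# [[5, 5, 4], [5, 5, 4], [5]]
```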

How can I do this? I saw the answer here: Tensorflow how to generate unbalanced combined data sets

It is a good attempt, but it doesn't work if one of the datasets gets emptied before the others (because the tf.errors.OutOfRangeError is raised preemptively when you try to fetch an element from the dataset that emptied first, and I don't get the last batch). Therefore I only get [5, 5, 4], [5, 5, 4]

I thought of using tf.contrib.data.choose_from_datasets:

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)
choice_dataset = tf.data.Dataset.from_tensor_slices([0, 1, 0, 1, 0])  # 0-based indices into [ds1, ds2]
ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)

This kind of works, but it is rather inelegant (there's the unbatch and batch); also, I don't really have fine control over exactly what ends up in a batch. (For example, if ds1 is [7] * 7 with batch size 2, and ds2 is [2, 2] with batch size 1, I get [7, 7, 2], [7, 7, 2], [7, 7, 7]. But what if I actually want [7, 7, 2], [7, 7, 2], [7, 7], [7]? I.e., keep the number of elements taken from each dataset fixed.)

Are there any other, better solutions?

My other idea was to try using tf.data.Dataset.flat_map:

import functools

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])
batch_sizes = [2, 1]

def concat_and_batch(*inputs):
  concat = functools.partial(functools.reduce, lambda x, y: x.concatenate(y))
  datasets = [tf.data.Dataset.from_tensors(input) for input in inputs]
  datasets = [dataset.batch(batch_size) for batch_size, dataset in zip(batch_sizes, datasets)]
  return concat(datasets)

dataset = (tf.data.Dataset
           .zip((ds1, ds2))
           .flat_map(concat_and_batch)
           .batch(sum(batch_sizes)))

But it doesn't seem to work...

Here is a solution. It has some issues, but I hope it answers your needs.

The idea is as follows: you batch each of the two datasets separately, zip them together, and then apply a map function that combines each zipped tuple into one batch (so far, this is similar to what is suggested in this and this answer.)

As you can see, the problem is that zipping works well only for two datasets of the same length. Otherwise, one dataset gets consumed before the other, and the remaining unconsumed elements are not used.
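Dataset.zip truncates just like Python's built-in zip, which makes the lost tail easy to see with plain lists standing in for the batched datasets:

```python
# Python's built-in zip has the same truncating behaviour as
# tf.data.Dataset.zip: iteration stops with the shorter input.
batches1 = [[5, 5], [5, 5], [5]]   # what ds1.batch(2) yields
batches2 = [[4], [4]]              # what ds2.batch(1) yields

zipped = [b1 + b2 for b1, b2 in zip(batches1, batches2)]
print(zipped)  # [[5, 5, 4], [5, 5, 4]] -- the final [5] is lost
```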

My (somewhat hacky) solution is to concatenate each of the two datasets with another, infinite dummy dataset. This dummy dataset consists only of values that you know will never appear in your real datasets. This eliminates the zipping problem. However, you then need to get rid of all the dummy elements. This can easily be done by filtering and mapping.

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

# we assume that this value will never occur in `ds1` and `ds2`:
UNUSED_VALUE = -1 

# an infinite dummy dataset:
dummy_ds = tf.data.Dataset.from_tensors(UNUSED_VALUE).repeat() 

# make `ds1` and `ds2` infinite:
ds1 = ds1.concatenate(dummy_ds)
ds2 = ds2.concatenate(dummy_ds)

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

# this is the solution mentioned in the links above
ds = tf.data.Dataset.zip((ds1,ds2))
ds = ds.map(lambda x1, x2: tf.concat((x1,x2),0))

# filter the infinite dummy tail:
ds = ds.filter(lambda x: tf.reduce_any(tf.not_equal(x,UNUSED_VALUE)))

# filter from batches the dummy elements:
ds = ds.map(lambda x: tf.boolean_mask(x,tf.not_equal(x,UNUSED_VALUE)))

This solution has two main issues:

(1) We need a value for UNUSED_VALUE that we are certain will not appear in the datasets. I suspect there is a workaround, perhaps by making the dummy dataset consist of empty tensors (rather than tensors with a constant value), but I don't know how to do that yet.

(2) Although this dataset has a finite number of elements, the following loop will never terminate:

iter = ds.make_one_shot_iterator()
batch = iter.get_next()
sess = tf.Session()
while True:
    print(sess.run(batch))

The reason is that the iterator keeps filtering out dummy examples without knowing when to stop. This can be addressed by changing the repeat() call above to repeat(n), where n is a number you know is larger than the difference between the lengths of the two datasets.
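To illustrate why a finite repeat(n) terminates, here is a plain-Python model of the padding trick (not TensorFlow; just the same pad, zip, filter, and mask logic applied to lists):

```python
# Pure-Python model of the dummy-padding trick with a finite repeat(n):
# pad both sequences with n dummy values, zip the batched sequences,
# drop all-dummy batches (the filter step), and mask out any remaining
# dummies inside a batch (the boolean_mask step).
UNUSED_VALUE = -1

def padded_batches(xs, ys, bs1, bs2, n):
    xs = xs + [UNUSED_VALUE] * n
    ys = ys + [UNUSED_VALUE] * n
    b1 = [xs[i:i + bs1] for i in range(0, len(xs), bs1)]
    b2 = [ys[i:i + bs2] for i in range(0, len(ys), bs2)]
    for p, q in zip(b1, b2):
        batch = [v for v in p + q if v != UNUSED_VALUE]
        if batch:        # filter: skip the all-dummy tail
            yield batch  # dummies already masked out above

print(list(padded_batches([5] * 5, [4, 4], 2, 1, n=4)))
# [[5, 5, 4], [5, 5, 4], [5]]
```

Because the padding is finite, the zipped stream itself ends, so iteration stops instead of filtering dummy batches forever.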

If you don't mind running a session during the construction of the new dataset, you can do the following:

import tensorflow as tf
import numpy as np

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next()
batch2 = iter2.get_next()

sess = tf.Session()

# define a generator that will sess.run both datasets, and will return the concatenation of both
def GetBatch():
    while True:
        try:
            b1 = sess.run(batch1)
        except tf.errors.OutOfRangeError:
            b1 = None
        try:
            b2 = sess.run(batch2)
        except tf.errors.OutOfRangeError:
            b2 = None
        if (b1 is None) and (b2 is None):
            break
        elif b1 is None:
            yield b2
        elif b2 is None:
            yield b1
        else:
            yield np.concatenate((b1,b2))

# create a dataset from the above generator
ds = tf.data.Dataset.from_generator(GetBatch,tf.int32)

Note that the session above can be hidden/encapsulated if you wish (e.g., inside a function), for example:

iter = ds.make_one_shot_iterator()
batch = iter.get_next()

sess2 = tf.Session()

while True:
    print(sess2.run(batch))

Here is a solution that requires you to use a "control input" to select which batch to use, where your choice depends on which dataset gets consumed first. This can be detected using the exception that is thrown.

To explain this solution, I will first present an attempt that does not work.

Attempted solution #1

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next(name='batch1')
batch2 = iter2.get_next(name='batch2')
batch12 = tf.concat((batch1, batch2), 0)

# this is a "control" placeholder. Its value determines whether to use `batch1`,`batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
               tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                       lambda:batch12,
        lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
                       lambda:batch1,
        lambda:batch2)) # else, use `batch2`

sess = tf.Session()

which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which==0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break

This solution doesn't work, since for any value of which_batch, the tf.cond() command evaluates all the predecessors of its branches. Therefore, even when which_batch takes the value 1, batch2 will be evaluated and an OutOfRangeError will be thrown.
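The failure mode is loosely analogous to eager evaluation in plain Python: because batch1 and batch2 are ops created outside the branch lambdas, tf.cond only selects which value to return, while running the graph still pulls on both iterators. A rough analogy (ordinary Python semantics, not graph execution):

```python
# Rough analogy: a value computed *before* the conditional is evaluated
# no matter which branch is later chosen, just as `batch2` is evaluated
# even when `which_batch` selects `batch1`.
def pick_first_batch():
    it1 = iter([[5, 5]])  # stands in for iter1 (still has data)
    it2 = iter([])        # stands in for iter2 (already exhausted)
    b1 = next(it1)
    try:
        b2 = next(it2)    # computed up-front, raises StopIteration
    except StopIteration: # analogue of OutOfRangeError
        return 'exhausted'
    return b1             # the branch that only needed b1

print(pick_first_batch())  # 'exhausted'
```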

Attempted solution #2

This can be solved by moving the definitions of batch1, batch2 and batch12 into functions.

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

def get_batch1():
    batch1 = iter1.get_next(name='batch1')
    return batch1

def get_batch2():
    batch2 = iter2.get_next(name='batch2')
    return batch2

def get_batch12():
    batch1 = iter1.get_next(name='batch1_')
    batch2 = iter2.get_next(name='batch2_')
    batch12 = tf.concat((batch1, batch2), 0)
    return batch12

# this is a "control" placeholder. Its value determines whether to use `batch1`, `batch2` or `batch12`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
               tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                       get_batch12,
        lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1`
                       get_batch1,
        get_batch2)) # elif `which_batch`==2, use `batch2`

sess = tf.Session()

which = 0 # this value will be fed into the control placeholder `which_batch`
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which==0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 2
        else:
            break

However, this doesn't work either. The reason is that, at the step when batch12 was formed and the dataset ds2 was consumed, we took a batch from dataset ds1 and "dropped" it without using it.

Solution

We need some mechanism that makes sure we don't "drop" any batch in the event that the other dataset gets consumed. We can do this by defining a variable that will be assigned the current batch of ds1, but only immediately before attempting to obtain batch12. Otherwise, this variable will keep its previous value. Then, if batch12 fails because ds1 was consumed, the assignment will have failed and batch2 was not dropped, so we can use it next time. Otherwise, if batch12 fails because ds2 was consumed, then we have a backup of batch1 in the variable we defined, and after using this backup we can proceed with taking batch1.

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

# this variable will store a backup of `batch1`, in case it is dropped
batch1_backup = tf.Variable(0, trainable=False, validate_shape=False)

def get_batch12():
    batch1 = iter1.get_next(name='batch1')

    # form the combined batch `batch12` only after backing-up `batch1`
    with tf.control_dependencies([tf.assign(batch1_backup, batch1, validate_shape=False)]):
        batch2 = iter2.get_next(name='batch2')
        batch12 = tf.concat((batch1, batch2), 0)
    return batch12

def get_batch1():
    batch1 = iter1.get_next()
    return batch1

def get_batch2():
    batch2 = iter2.get_next()
    return batch2

# this is a "control" placeholder. Its value determines whether to use `batch12`, `batch1_backup`, `batch1`, or `batch2`
which_batch = tf.placeholder(tf.int32)

batch = tf.cond(
               tf.equal(which_batch,0), # if `which_batch`==0, use `batch12`
                       get_batch12,
        lambda:tf.cond(tf.equal(which_batch,1), # elif `which_batch`==1, use `batch1_backup`
                       lambda:batch1_backup,
        lambda:tf.cond(tf.equal(which_batch,2), # elif `which_batch`==2, use `batch1`
                       get_batch1,
       get_batch2))) # else, use `batch2`

sess = tf.Session()
sess.run(tf.global_variables_initializer())

which = 0  # this value will be fed into the control placeholder
while True:
    try:
        print(sess.run(batch,feed_dict={which_batch:which}))

        # if just used `batch1_backup`, proceed with `batch1`
        if which==1:
            which = 2
    except tf.errors.OutOfRangeError as e:
        # use the error to detect which dataset was consumed, and update `which` accordingly
        if which == 0:
            if 'batch2' in e.op.name:
                which = 1
            else:
                which = 3
        else:
            break