Tensorflow 数据集 - 如何在给定生成器为 1 个标签输出 X 个输入的情况下构建批次?

Tensorflow Dataset - How to build batchs given a generator outputting X inputs for 1 label?

精简版

给定一个生成器采样,例如3 输入和 1 标签,我如何定义我的 Tensorflow 数据集管道以获取成批的 K * 3 输入和 K * 1 标签?


加长版

上下文

我正在使用 Triplet 网络,并希望调整我当前的输入管道以使用 Tensorflow 数据集。

在我的例子中,一个批次由 N(例如图像)和 N // 3 标签(假设 N % 3 == 0)组成,每个标签应用于 3 个连续的输入,例如

labels = [compute_label(inputs[3*i], inputs[3*i+1], inputs[3*i+2]) for i in range(N // 3)]

with compute_label(*args) 一个简单的函数,可以用 Tensorflow 操作或基本的 Python.

来实现

为了让事情变得更复杂一点,输入元素必须被 3×3 采样(例如,我们希望 inputs[3*i] 类似于 inputs[3*i+1] 而不同于 inputs[3*i+2]):

for i in range(N // 3):
    inputs[3*i], inputs[3*i+1], inputs[3*i+2] = sample_triplet(i)

问题

针对我的具体情况重新制定较短的问题:

鉴于这两个函数 sample_triplet()compute_label(),我如何使用 Tensorflow 数据集构建我的输入管道,以使用 N 输入和 N // 3标签?

我尝试了 tf.data.Dataset.from_generator()tf.data.Dataset.flat_map() 的几种组合,但无法找到一种方法将批输入从 N // 3 三元组扁平化为 N 样本,并且仅输出 N // 3 个批次标签。

我找到的一个解决方案是 "cheat" 通过在 tf.data.Dataset.from_generator() 内计算我的标签并将每个标签平铺 3 次,以便能够在三元组输入 + 标签上使用 tf.data.Dataset.flat_map()。作为批处理后处理步骤,我然后 "squeezing" N 重复的标签返回 N // 3

当前解决方案示例

import tensorflow as tf
import numpy as np

def sample_triplet():
    # Sampling our elements, here as [class, random_val] elements:
    anchor_class = puller_class = pusher_class = np.random.randint(0, 10)
    while pusher_class == anchor_class:
        # we want the pusher to be of a different class
        pusher_class = np.random.randint(0, 10) 

    anchor = np.array([anchor_class, np.random.randint(0, 5)])
    puller = np.array([puller_class, np.random.randint(0, 5)])
    pusher = np.array([pusher_class, np.random.randint(0, 5)])

    # Stacking the triplets, to then flat_map as a batch:
    triplet_inputs = np.stack((anchor, puller, pusher), axis=0)
    # Returning also the classes to compute the label afterwards:
    triplet_classes = np.stack((anchor_class, puller_class, pusher_class), axis=0)
    return triplet_inputs, triplet_classes

def compute_labels(triplet_classes):
    # Computing the label, e.g. distance between the anchor and pusher classes:
    label = np.abs(triplet_classes[0] - triplet_classes[2])
    return label

def triplet_generator():
    while True:
        triplet = sample_triplet()

        # Current solution: computing the label here too, 
        # stacking it 3 times so that flat_map works,
        # then afterwards removing the duplicates:
        triplet_inputs = triplet[0]
        triplet_label = compute_labels(triplet[1])
        yield triplet_inputs, 
              np.stack((triplet_label, triplet_label, triplet_label), axis=0)

def squeeze_triplet_labels(*batch):
    # Removing the duplicate labels,
    # going from a batch of (N inputs, N labels) to (N inputs, N // 3 labels)
    squeezed_labels = batch[-1][::3]
    new_batch = (*batch[:-1], squeezed_labels)
    return new_batch


batch_size = 30
assert(batch_size % 3 == 0)
sess = tf.InteractiveSession()
train_dataset = (tf.data.Dataset
                 .from_generator(triplet_generator, (tf.int32, tf.float32), ([3, 2], [3]))
                 .flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x))
                 .batch(batch_size))

next_training_batch = train_dataset.make_one_shot_iterator().get_next()
next_proper_training_batch = squeeze_triplet_labels(*next_training_batch)
batch = sess.run(next_proper_training_batch)
print("inputs shape: {} ; label shape: {}".format(batch[0].shape, batch[1].shape))
# >> inputs shape: (30, 2) ; label shape: (10,)

一个简单的解决方案可能是创建 2 个数据集对象,一个用于标签,一个用于数据,然后按 3 个一组对数据进行批处理,并使用 tf.data.interleave 将两个数据集交错在一起,产生结果你要。

如果这不容易做到,那么您可以尝试以下将一个元素映射到多个元素的过程。您将必须创建一批 3 个元素(带有 3 个标签),然后在 map 函数中将其拆分为 3 组数据,每组针对您收到的一个标签。这样做的方法是在下面的 SO 问题中,尽管它比第一个建议更复杂: