如何为三元组损失制作数据集

How to make Dataset for triplet loss

我正在尝试制作可提供多批 TFRecords 的数据集,其中一批将有来自一个 class 的 2 个随机记录,其余来自另一个随机 classes.

批次的数据集,其中每个 class 有 2 个随机记录适合该批次。

我尝试用 tf.data.Dataset.from_generatortf.data.experimental.choose_from_datasets 来做到这一点,但没有成功。您知道如何执行此操作吗?

编辑: 今天我想我实现了第二个变体。这是我测试它的代码。

def input_fn():
  partial1 = tf.data.Dataset.from_tensor_slices(tf.range(0, 10)).repeat().shuffle(2)
  partial2 = tf.data.Dataset.from_tensor_slices(tf.range(20, 30)).repeat().shuffle(2)
  partial3 = tf.data.Dataset.from_tensor_slices(tf.range(60, 70)).repeat().shuffle(2)
  l = [partial1, partial2, partial3]

  def gen(x):
    return tf.data.Dataset.range(x,x+1).repeat(2)

  dataset = tf.data.Dataset.range(3).flat_map(gen).repeat(10)

  choice = tf.data.experimental.choose_from_datasets(l, dataset).batch(4)
  return choice

评估时 returns

[ 0  2 21 22]
[60 61  1  4]
[20 23 62 63]
[ 3  5 24 25]
[64 66  6  7]
[26 27 65 68]
[ 8  0 28 29]
[67 69  9  2]
[20 22 60 62]
[ 3  1 23 24]
[63 61  4  6]
[25 26 65 64]
[ 7  5 27 28]
[67 66  9  8]
[21 20 69 68]

好的,我明白了。数据集生成成功,数据随机性似乎不错。这不是三元组损失的理想解决方案,因为三元组是随机的而不是半硬的。

def input_fn(self, params):
    batch_size = params['batch_size']

    assert self.data_dir, 'data_dir is required'
    shuffle = self.is_training

    dirs = list(map(lambda x: os.path.join(x, 'train-*' if self.is_training else 'validation-*')), self.dirs)

    def prefetch_dataset(filename): 
      dataset = tf.data.TFRecordDataset( 
          filename, buffer_size=FLAGS.prefetch_dataset_buffer_size)
      return dataset

    datasets = []
    for glob in dirs:
      dataset = tf.data.Dataset.list_files(glob)
      dataset = dataset.apply( 
        tf.contrib.data.parallel_interleave( 
            prefetch_dataset, 
            cycle_length=FLAGS.num_files_infeed, 
            sloppy=True)) # if order is important 
      dataset = dataset.shuffle(batch_size, None, True).repeat().prefetch(batch_size)
      datasets.append(dataset)

    def gen(x):
      return tf.data.Dataset.range(x,x+1).repeat(2)

    choice = tf.data.Dataset.range(len(datasets)).repeat().flat_map(gen)

    dataset = tf.data.experimental.choose_from_datasets(datasets, choice).map( # apply function to each element of the dataset in parallel
        self.dataset_parser, num_parallel_calls=FLAGS.num_parallel_calls)

    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(8)

    return dataset

在 TF 2.0 中,现在可以使用 dataset.interleave 读取差异 class 的 tfrecords,并使用 dataset.batch 生成三元组对 :

h = FcaeRecHelper('data/ms1m_img_ann.npy', [112, 112], 128, use_softmax=False)
len(h.train_list)
img_shape = list(h.in_hw) + [3]

is_augment = True
is_normlize = False

def parser(stream: bytes):
    # parser tfrecords
    examples: dict = tf.io.parse_single_example(
        stream,
        {'img': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)})
    return tf.image.decode_jpeg(examples['img'], 3), examples['label']

def pair_parser(raw_imgs, labels):
    # imgs do same augment ~
    if is_augment:
        raw_imgs, _ = h.augment_img(raw_imgs, None)
    # normlize image
    if is_normlize:
        imgs: tf.Tensor = h.normlize_img(raw_imgs)
    else:
        imgs = tf.cast(raw_imgs, tf.float32)

    imgs.set_shape([4] + img_shape)
    labels.set_shape([4, ])
    # Note y_true shape will be [batch,3]
    return (imgs[0], imgs[1], imgs[2]), (labels[:3])

batch_size = 1
# h.train_list : ['a.tfrecords','b.tfrecords','c.tfrecords',...]
ds = (tf.data.Dataset.from_tensor_slices(h.train_list)
        .interleave(lambda x: tf.data.TFRecordDataset(x)
                    .shuffle(100)
                    .repeat(), cycle_length=-1,
                    # block_length = 2 is important
                    block_length=2,
                    num_parallel_calls=-1)
        .map(parser, -1)
        .batch(4, True)
        .map(pair_parser, -1)
        .batch(batch_size, True))

iters = iter(ds)
for i in range(20):
    imgs, labels = next(iters)
    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(imgs[0].numpy().astype('uint8')[0])
    axs[1].imshow(imgs[1].numpy().astype('uint8')[0])
    axs[2].imshow(imgs[2].numpy().astype('uint8')[0])
    plt.show()