如何填充固定 BATCH_SIZE in tf.data.Dataset?

How to pad to fixed BATCH_SIZE in tf.data.Dataset?

我有一个包含 11 个样本的数据集。而当我选择BATCH_SIZE为2时,下面的代码会出现错误:

dataset = tf.contrib.data.TFRecordDataset(filenames) 
dataset = dataset.map(parser)
if shuffle:
    dataset = dataset.shuffle(buffer_size=128)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(count=1)

问题出在dataset = dataset.batch(batch_size),当Dataset循环到最后一批时,剩下的样本数只有1,请问有没有办法从上次访问的样本中随机抽取一个采样并生成最后一批?

@mining 提出了一个通过填充文件名的解决方案。

另一个解决方案是使用 tf.contrib.data.batch_and_drop_remainder。这将以固定的批量大小对数据进行批处理,并删除最后一个较小的批次。

在您的示例中,有 11 个输入和 2 个批次大小,这将产生 5 个批次,每批次有 2 个元素。

这是文档中的示例:

dataset = tf.data.Dataset.range(11)
batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(2))

您只需在 batch 的调用中设置 drop_remainder=True

dataset = dataset.batch(batch_size, drop_remainder=True)

来自documentation

drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped in the case its has fewer than batch_size elements; the default behavior is not to drop the smaller batch.