Output differences when changing order of batch(), shuffle() and repeat()

I created a tensorflow dataset, made it repeatable, shuffled it, split it into batches, and built an iterator to get the next batch. But when I do this, elements are sometimes repeated (both within and across batches), especially for small datasets. Why?
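A minimal sketch of the kind of pipeline I mean (illustrative values; repeat, then shuffle, then batch):

import tensorflow as tf

ds = tf.data.Dataset.range(10)   # small toy dataset
ds = ds.repeat()                 # make it repeatable
ds = ds.shuffle(100)             # shuffle it
ds = ds.batch(2)                 # split it into batches
next_batch = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(5):
        # duplicates can show up within and across batches, e.g. [3 3]
        print(sess.run(next_batch))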

You have to shuffle before you repeat!

As the following two snippets show, the order in which you shuffle and repeat matters.

Worst ordering:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
ds = ds.batch(2)
ds = ds.repeat()
ds = ds.shuffle(100000)
iterator   = ds.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())

Output:

------------
00: [6 7]
01: [2 3]
02: [6 7]
03: [0 1]
04: [8 9]
------------
05: [6 7]
06: [4 5]
07: [6 7]
08: [4 5]
09: [0 1]
------------
10: [2 3]
11: [0 1]
12: [0 1]
13: [2 3]
14: [4 5]

Better ordering:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
ds = ds.batch(2)
ds = ds.shuffle(100000)
ds = ds.repeat()
iterator   = ds.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())

Output:

------------
00: [4 5]
01: [6 7]
02: [8 9]
03: [0 1]
04: [2 3]
------------
05: [0 1]
06: [4 5]
07: [8 9]
08: [2 3]
09: [6 7]
------------
10: [0 1]
11: [4 5]
12: [8 9]
13: [2 3]
14: [6 7]

Best ordering:

Inspired by GPhilo's answer, the order of batching also matters. For the batches to be different in each epoch, you must shuffle first, then repeat, and batch last. As can be seen in the output below, every batch is unique, unlike with the other orderings.

import tensorflow as tf

ds = tf.data.Dataset.range(10)

ds = ds.shuffle(100000)
ds = ds.repeat()
ds = ds.batch(2)

iterator   = ds.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())

Output:

------------
00: [2 5]
01: [1 8]
02: [9 6]
03: [3 4]
04: [7 0]
------------
05: [4 3]
06: [0 2]
07: [1 9]
08: [6 5]
09: [8 7]
------------
10: [7 3]
11: [5 9]
12: [4 1]
13: [8 6]
14: [0 2]

Contrary to what is stated in your own answer: no, shuffling and then repeating won't fix your problems.

The key source of your problem is that you batch, then shuffle/repeat. That way, the items in your batches will always be taken from contiguous samples of the input dataset. Batching should be one of the last operations you perform in your input pipeline.

Expanding the question slightly.

Now, there is a difference in the order in which you shuffle, repeat and batch, but it's not what you think. Quoting from the input pipeline performance guide:

If the repeat transformation is applied before the shuffle transformation, then the epoch boundaries are blurred. That is, certain elements can be repeated before other elements appear even once. On the other hand, if the shuffle transformation is applied before the repeat transformation, then performance might slow down at the beginning of each epoch related to initialization of the internal state of the shuffle transformation. In other words, the former (repeat before shuffle) provides better performance, while the latter (shuffle before repeat) provides stronger ordering guarantees.

Recapping:

  • Repeat, then shuffle: you lose the guarantee that all samples are processed within one epoch.
  • Shuffle, then repeat: you are guaranteed that all samples will be seen before the next repetition begins, but you lose (slightly) in performance.

Whichever you pick, do it before batching; the sketch below contrasts the two orderings.
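A minimal sketch of the two orderings (TF 1.x style to match the examples above; dataset size and buffer sizes are illustrative):

import tensorflow as tf

# Repeat before shuffle: epoch boundaries are blurred, so an element can
# be drawn a second time before some other element has appeared at all.
blurred = tf.data.Dataset.range(4).repeat().shuffle(100)

# Shuffle before repeat: every block of 4 elements is a complete,
# freshly re-shuffled permutation of the epoch.
ordered = tf.data.Dataset.range(4).shuffle(4).repeat()

with tf.Session() as sess:
    nxt = blurred.make_one_shot_iterator().get_next()
    print([sess.run(nxt) for _ in range(8)])  # e.g. [3, 0, 0, 2, 1, 3, 1, 0]

    nxt = ordered.make_one_shot_iterator().get_next()
    print([sess.run(nxt) for _ in range(8)])  # two full permutations of 0..3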

For instance, if you want the same behaviour as Keras' .fit() function, you can use:

dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.repeat(EPOCHS)

This will go through the dataset in the same way that .fit(epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=True) does. A quick example (eager execution is enabled only for readability; the behaviour is the same in graph mode):

import numpy as np
import tensorflow as tf
tf.enable_eager_execution()

NUM_SAMPLES = 7
BATCH_SIZE = 3
EPOCHS = 2

# Create the dataset
x = np.array([[2 * i, 2 * i + 1] for i in range(NUM_SAMPLES)])
dataset = tf.data.Dataset.from_tensor_slices(x)

# Shuffle, batch and repeat the dataset
dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.repeat(EPOCHS)

# Iterate through the dataset (under eager execution we can loop over it directly)
for batch in dataset:
    print(batch.numpy(), end='\n\n')

This prints:

[[ 8  9]
 [12 13]
 [10 11]]

[[0 1]
 [2 3]
 [4 5]]

[[6 7]]

[[ 4  5]
 [10 11]
 [12 13]]

[[6 7]
 [0 1]
 [2 3]]

[[8 9]]

You can see that the batches are different in every epoch, even though .batch() was called after .shuffle(). This is why we need reshuffle_each_iteration=True. If we did not reshuffle at each iteration, we would get the same batches in every epoch.
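As a sketch, the only change to the pipeline above is the shuffle call (reusing x, BATCH_SIZE and EPOCHS from the example):

dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(10000, reshuffle_each_iteration=False)  # same order every epoch
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.repeat(EPOCHS)

Iterating over it as before prints: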

[[12 13]
 [ 4  5]
 [10 11]]

[[6 7]
 [8 9]
 [0 1]]

[[2 3]]

[[12 13]
 [ 4  5]
 [10 11]]

[[6 7]
 [8 9]
 [0 1]]

[[2 3]]

This can be harmful when training on small datasets.