Tensorflow dataset.shuffle 似乎没有 repeat()

Tensorflow dataset.shuffle seems not shuffle without repeat()

我的代码与 tensorflow 2.0 tutorial 有相似的模式。 我希望我的数据集对象在每个时期都重新洗牌。

dataset = tf.data.Dataset.from_tensor_slices(['a','b','c','d'])
dataset = dataset.shuffle(100)

for epoch in range(10):
    for d in dataset:
        print(d)

结果:

tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
tf.Tensor(b'c', shape=(), dtype=string)
tf.Tensor(b'a', shape=(), dtype=string)
tf.Tensor(b'b', shape=(), dtype=string)
tf.Tensor(b'd', shape=(), dtype=string)
...

数据集似乎没有为每个时期打乱顺序。 我应该为每个时期调用 .shuffle() 吗?

是的,您应该在内循环中调用 .shuffle。此外,当相当于 Python 语句的纯 tf.* 方法可用时,最好不要混合 python 代码和 TensorFlow 代码。

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(["a", "b", "c", "d"])
# dataset = dataset.shuffle(2)


@tf.function
def loop():
    for epoch in tf.range(10):
        for d in dataset.shuffle(2):
            tf.print(d)


loop()

循环调用每次都会产生不同的值(tf.print 打印 tf.Tensor 的内容,不同于打印对象的 print