How to shuffle data at each epoch using tf.data API in TensorFlow 2.0?

I am trying to train my model with TensorFlow 2.0. The new iteration features in the tf.data API are great. However, when I run the code below, I find that, unlike iterating with torch.utils.data.DataLoader, it does not automatically reshuffle the data at each epoch. How can I achieve that with TF 2.0?

import numpy as np
import tensorflow as tf
def sample_data():
    ...

data = sample_data()

NUM_EPOCHS = 10
BATCH_SIZE = 128

# Subsample the data
mask = range(int(data.shape[0]*0.8), data.shape[0])
data_val = data[mask]
mask = range(int(data.shape[0]*0.8))
data_train = data[mask]

train_dset = (tf.data.Dataset.from_tensor_slices(data_train)
              .shuffle(buffer_size=10000)
              .repeat(1)
              .batch(BATCH_SIZE))
val_dset = (tf.data.Dataset.from_tensor_slices(data_val)
            .batch(BATCH_SIZE))


loss_metric = tf.keras.metrics.Mean(name='train_loss')
optimizer = tf.keras.optimizers.Adam(0.001)

@tf.function
def train_step(inputs):
    ...

for epoch in range(NUM_EPOCHS):
    # Reset the metrics
    loss_metric.reset_states()
    for inputs in train_dset:
        train_step(inputs)
    ...

Shuffle after batching, so it is the batches that get reshuffled. Since `shuffle`'s `reshuffle_each_iteration` argument defaults to `True`, the order is re-randomized every time the dataset is iterated, i.e. on every epoch:

train_dset = (tf.data.Dataset.from_tensor_slices(data_train)
              .repeat(1)
              .batch(BATCH_SIZE))

# reshuffle_each_iteration defaults to True, so each pass over
# train_dset in the epoch loop sees a fresh batch order
train_dset = train_dset.shuffle(buffer_size=buffer_size)
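For intuition, the per-epoch reshuffling that `torch.utils.data.DataLoader(shuffle=True)` and `shuffle(..., reshuffle_each_iteration=True)` both provide can be sketched in plain Python. The `epochs` helper below is a hypothetical illustration of the behavior, not part of either API:

```python
import random

def epochs(data, num_epochs, seed=None):
    """Yield the data in a fresh random order each epoch,
    mimicking reshuffle_each_iteration=True."""
    rng = random.Random(seed)
    for _ in range(num_epochs):
        order = list(data)
        rng.shuffle(order)  # a new permutation on every epoch
        yield order

# Two epochs over the same data come out in different orders,
# but each epoch still contains every element exactly once.
e1, e2 = epochs(range(100), num_epochs=2, seed=0)
```

This is exactly what the `for inputs in train_dset:` loop gives you once the shuffle is set up as above: re-iterating the dataset triggers a new permutation.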