为什么在 tf.keras.dataset 中打乱数据序列会对 tf.fit 和 tf.predict 中的序列顺序产生不同的影响?
Why does shuffling sequences of data in tf.keras.dataset affect the order of sequences differently between tf.fit and tf.predict?
我正在使用时间序列和标签训练 LSTM 深度学习模型。
我生成一个张量流数据集“train_data”和“test_data”
train_data = tf.keras.preprocessing.timeseries_dataset_from_array(
data=data,
targets=None,
sequence_length=total_window_size,
sequence_stride=1,
batch_size=batch_size,
shuffle=is_shuffle).map(split_window).prefetch(tf.data.AUTOTUNE)
然后我用上述数据集训练模型
model.fit(train_data, epochs=epochs, validation_data = test_data, callbacks=callbacks)
然后运行预测得到预测值
train_labels = np.concatenate([y for x, y in train_data], axis=0)
train_predictions = model.predict(train_data)
test_labels = np.concatenate([y for x, y in test_data], axis=0)
test_predictions = model.predict(test_data)
这是我的问题:当我根据预测值绘制 train/test 标签数据时,当我执行 not 随机排列数据集中的序列时,我得到以下图构建步骤:
这里输出和洗牌:
问题为什么会这样?我使用完全相同的源数据集进行训练和预测。数据集应该被洗牌。 TensorFlow 是否有可能随机打乱数据两次,一次在训练期间,另一次用于预测?我尝试提供随机播放种子,但这也没有改变。
每次迭代时,数据集都会被打乱。列表理解后得到的结果与编写 predict
时的顺序不同。如果你不想那样,通过:
shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=False)
我正在使用时间序列和标签训练 LSTM 深度学习模型。
我生成一个张量流数据集“train_data”和“test_data”
train_data = tf.keras.preprocessing.timeseries_dataset_from_array(
data=data,
targets=None,
sequence_length=total_window_size,
sequence_stride=1,
batch_size=batch_size,
shuffle=is_shuffle).map(split_window).prefetch(tf.data.AUTOTUNE)
然后我用上述数据集训练模型
model.fit(train_data, epochs=epochs, validation_data = test_data, callbacks=callbacks)
然后运行预测得到预测值
train_labels = np.concatenate([y for x, y in train_data], axis=0)
train_predictions = model.predict(train_data)
test_labels = np.concatenate([y for x, y in test_data], axis=0)
test_predictions = model.predict(test_data)
这是我的问题:当我根据预测值绘制 train/test 标签数据时,当我执行 not 随机排列数据集中的序列时,我得到以下图构建步骤:
这里输出和洗牌:
问题为什么会这样?我使用完全相同的源数据集进行训练和预测。数据集应该被洗牌。 TensorFlow 是否有可能随机打乱数据两次,一次在训练期间,另一次用于预测?我尝试提供随机播放种子,但这也没有改变。
每次迭代时,数据集都会被打乱。列表理解后得到的结果与编写 predict
时的顺序不同。如果你不想那样,通过:
shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=False)