如何重塑 Tensorflow 数据集中的数据?

How to reshape data in Tensorflow dataset?

我正在编写一个数据管道,以将成批的时间序列序列和相应的标签提供给需要 3D 输入形状的 LSTM 模型。我目前有以下内容:

def split(window):
    return window[:-label_length], window[-label_length]

dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length, shift=label_shift, stride=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer, seed=shuffle_seed, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size, drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

for x, y in dataset.take(1): x.shape 的结果形状是 (32, 20),其中 32 是批量大小,20 是序列长度,但我需要 (32, 20, 1) 的形状,其中附加维度表示特征。

我的问题是如何重塑,最好是在缓存数据之前传递给 dataset.map 函数的 split 函数中?

这很简单。在您的拆分函数中执行此操作

def split(window):
    return window[:-label_length, tf.newaxis], window[-label_length, tf.newaxis, tf.newaxis]