如何在TensorFlow 2.0中使用Dataset.window()方法创建的windows?
How to use windows created by the Dataset.window() method in TensorFlow 2.0?
我正在尝试创建一个数据集,它将 return 从时间序列中随机 windows,并使用 TensorFlow 2.0 将下一个值作为目标。
我正在使用 Dataset.window()
,看起来很有前途:
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
for window in dataset:
print([elem.numpy() for elem in window])
输出:
[0, 1, 2, 3, 4]
[1, 2, 3, 4, 5]
[2, 3, 4, 5, 6]
[3, 4, 5, 6, 7]
[4, 5, 6, 7, 8]
[5, 6, 7, 8, 9]
不过,我想使用最后一个值作为目标。如果每个 window 都是张量,我会使用:
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
但是,如果我尝试这样做,我会遇到异常:
TypeError: '_VariantDataset' object is not subscriptable
解决方法是这样调用flat_map()
:
dataset = dataset.flat_map(lambda window: window.batch(5))
现在数据集中的每一项都是一个window,所以你可以这样拆分它:
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
所以完整的代码是:
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(5))
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
for X, y in dataset:
print("Input:", X.numpy(), "Target:", y.numpy())
输出:
Input: [0 1 2 3] Target: [4]
Input: [1 2 3 4] Target: [5]
Input: [2 3 4 5] Target: [6]
Input: [3 4 5 6] Target: [7]
Input: [4 5 6 7] Target: [8]
Input: [5 6 7 8] Target: [9]
我正在尝试创建一个数据集,它将 return 从时间序列中随机 windows,并使用 TensorFlow 2.0 将下一个值作为目标。
我正在使用 Dataset.window()
,看起来很有前途:
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
for window in dataset:
print([elem.numpy() for elem in window])
输出:
[0, 1, 2, 3, 4]
[1, 2, 3, 4, 5]
[2, 3, 4, 5, 6]
[3, 4, 5, 6, 7]
[4, 5, 6, 7, 8]
[5, 6, 7, 8, 9]
不过,我想使用最后一个值作为目标。如果每个 window 都是张量,我会使用:
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
但是,如果我尝试这样做,我会遇到异常:
TypeError: '_VariantDataset' object is not subscriptable
解决方法是这样调用flat_map()
:
dataset = dataset.flat_map(lambda window: window.batch(5))
现在数据集中的每一项都是一个window,所以你可以这样拆分它:
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
所以完整的代码是:
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.window(5, shift=1, drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(5))
dataset = dataset.map(lambda window: (window[:-1], window[-1:]))
for X, y in dataset:
print("Input:", X.numpy(), "Target:", y.numpy())
输出:
Input: [0 1 2 3] Target: [4]
Input: [1 2 3 4] Target: [5]
Input: [2 3 4 5] Target: [6]
Input: [3 4 5 6] Target: [7]
Input: [4 5 6 7] Target: [8]
Input: [5 6 7 8] Target: [9]