让 Tensorflow 只考虑连续的数据行

Make Tensorflow only consider consecutive rows of data

我的数据由时间戳和其他一些数据字段组成。然而,我的一些条目不被考虑用于学习,因此我将它们从数据中删除。我最终得到了这样的数据集:

1
2
3
4
5
7
8
9
10

(注意 6 处的差距)。现在,为了学习,我希望 Tensorflow 考虑最后 3 行(并从下一行获取预测标签),但要考虑间隙。在我的示例中,有效数据包将是 (1,2,3,4)、(2,3,4,5) 和 (7,8,9,10),但不是例如(3,4,5,7).

我研究了 Tensorflow API,似乎 Datasets 自己的实现可能会成功,虽然乍一看 class 没有看起来很适合这种方法(例如,没有抽象超级 class,其中只有一些微小的 next() 方法必须实现 ;-))。

还有其他想法吗?你会如何解决这个问题?

我认为最直接的方法是使用 tf.data.Dataset API 的窗口功能,并过滤相关值。

例如,如果重复使用您的示例:

# creating a dataset of the values 1 to 10
ds = tf.data.Dataset.range(1,11)
# elements that we don't want in the dataset
to_remove = tf.constant([6])
# creating windows of size 4 with a shfit of 1. We keep only windows of size 4
windows = ds.window(size=4, shift=1, drop_remainder=True)
# window returns a Dataset of Dataset, we flatten it to get a Dataset of Tensor
windows = windows.flat_map(lambda window: window.batch(4, drop_remainder=True))
# we filter to keep only the correct elements
filtered = windows.filter(lambda x: not tf.reduce_any(tf.equal(x,to_remove[:,tf.newaxis])))

如果我们查看最终数据集:

>>> for data in filtered:
        print(data)
tf.Tensor([1 2 3 4], shape=(4,), dtype=int32)
tf.Tensor([2 3 4 5], shape=(4,), dtype=int32)
tf.Tensor([ 7  8  9 10], shape=(4,), dtype=int32)