tensorflow 数据集滑动 window 批处理不起作用?

tensorflow dataset sliding window batch not working?

我无法让这段代码工作,我哪里错了?

dataset = tf.data.Dataset.from_tensors(np.arange(8))
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=4))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(element))
        except tf.errors.OutOfRangeError:
            print('end')
            break

我本以为 [0,1,2,3],[1,2,3,4],... 但我什么也没得到。

编辑: 如果我在 apply 之前执行 print(dataset) 我会得到 <TensorDataset shapes: (8,), types: tf.int64>,在 apply 之后我会得到 <_SlideDataset shapes: (?, 8), types: tf.int64>,这不是我所期望的:不应该是 _SlideDataset(?, 4)?

将代码从 from_tensors 更改为 from_tensor_slices。请参阅下面的代码更新:

import tensorflow as tf
import numpy as np

dataset = tf.data.Dataset.from_tensor_slices((np.arange(8)))
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=4))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()

with tf.Session() as sess:
    while True:
        try:
            print(sess.run(element))
        except tf.errors.OutOfRangeError:
            print('end')
            break

输出:

[0 1 2 3]
[1 2 3 4]
[2 3 4 5]
[3 4 5 6]
[4 5 6 7]
end