tensorflow 数据集滑动 window 批处理不起作用?
tensorflow dataset sliding window batch not working?
我无法让这段代码工作,我哪里错了?
dataset = tf.data.Dataset.from_tensors(np.arange(8))
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=4))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(element))
except tf.errors.OutOfRangeError:
print('end')
break
我本以为 [0,1,2,3],[1,2,3,4],...
但我什么也没得到。
编辑:
如果我在 apply
之前执行 print(dataset)
我会得到 <TensorDataset shapes: (8,), types: tf.int64>
,在 apply
之后我会得到 <_SlideDataset shapes: (?, 8), types: tf.int64>
,这不是我所期望的:不应该是 _SlideDataset
是 (?, 4)
?
将代码从 from_tensors
更改为 from_tensor_slices
。请参阅下面的代码更新:
import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices((np.arange(8)))
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=4))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(element))
except tf.errors.OutOfRangeError:
print('end')
break
输出:
[0 1 2 3]
[1 2 3 4]
[2 3 4 5]
[3 4 5 6]
[4 5 6 7]
end
我无法让这段代码工作,我哪里错了?
dataset = tf.data.Dataset.from_tensors(np.arange(8))
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=4))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(element))
except tf.errors.OutOfRangeError:
print('end')
break
我本以为 [0,1,2,3],[1,2,3,4],...
但我什么也没得到。
编辑:
如果我在 apply
之前执行 print(dataset)
我会得到 <TensorDataset shapes: (8,), types: tf.int64>
,在 apply
之后我会得到 <_SlideDataset shapes: (?, 8), types: tf.int64>
,这不是我所期望的:不应该是 _SlideDataset
是 (?, 4)
?
将代码从 from_tensors
更改为 from_tensor_slices
。请参阅下面的代码更新:
import tensorflow as tf
import numpy as np
dataset = tf.data.Dataset.from_tensor_slices((np.arange(8)))
dataset = dataset.apply(tf.contrib.data.sliding_window_batch(window_size=4))
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(element))
except tf.errors.OutOfRangeError:
print('end')
break
输出:
[0 1 2 3]
[1 2 3 4]
[2 3 4 5]
[3 4 5 6]
[4 5 6 7]
end