tensorflow 数据集的返回大小 API 不是常量

returned size of tensorflow's dataset API is not constant

我正在使用 tensorflow 的 dataset API。并用简单的案例测试我的代码。下面显示了我使用的简单代码。问题是,当数据集较小时,数据集 API 返回的大小似乎不一致。我相信有一个正确的方法来处理它。但即使我阅读了该页面和教程中的所有功能,我也找不到。

import numpy as np
import tensorflow as tf

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(16)
dataset = dataset.repeat()

iterator = tf.contrib.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
next_element = iterator.get_next()
training_init_op = iterator.make_initializer(dataset)

with tf.Session() as sess:
    sess.run(training_init_op)
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))
    next_elem = next_element.eval()
    print(np.shape(next_elem))

数据集是灰度视频。共有24个视频序列,步长均为200。帧大小为64×64,单通道。我将批量大小设置为 16,将缓冲区大小设置为 100。但是代码的结果是,

(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(8, 200, 64, 64, 1)
(16, 200, 64, 64, 1)

返回的video size不是16就是8,我猜是因为原来的data size比较小,24,到了data的末尾,API就returns剩下什么。

但是我不明白。我还将缓冲区大小设置为 100。这意味着缓冲区应该预先用小数据集填充。从该缓冲区中,API 应该 select next_element 其批量大小为 16。

我在tensorflow中使用queue-typeAPI的时候没有出现这个问题。无论原始数据的大小如何,总有一个迭代器到达数据集末尾的时刻。我想知道其他人是如何使用这个 API 解决这个问题的。

尝试在 batch() 之前调用 repeat():

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel]
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.repeat()
dataset = dataset.batch(16)

我得到的结果:

(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)
(16, 200, 64, 64, 1)

您可以使用以下代码解决问题:

batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128))