如何从 Tensorflow tf.data.Dataset 中无休止地阅读?
How can I read endlessly from a Tensorflow tf.data.Dataset?
我正在将我的旧数据层(使用队列)切换到 "new" 并推荐数据集 API。我是第一次使用它,所以我提供了代码示例以防我遇到根本性错误。
我从一个生成器创建我的数据集(它将读取一个文件,并提供 n 个样本)。这是一个小数据集 n_iterations >> n_samples,所以我只是想一遍又一遍地阅读这个数据集,最好是打乱顺序。
sample_set = tf.data.Dataset.from_generator( data_generator(filename),
(tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1]))
)
使用数据生成器:
class data_generator:
def __init__(self, filename):
self.filename= filename
def __call__(self):
with filename.open() as f:
for idx in f: yield img[idx], label[idx]
要实际使用数据,我知道我需要定义一个 Iterator
sample = sample_set.make_one_shot_iterator().get_next()
然后我们设置读取数据
while True:
try: my_sample = sess.run(sample)
except tf.errors.OutOfRangeError: break # this happens after dset is read once
但所有可用的迭代器似乎都是 "finite",因为它们只读取数据集一次。
是否有一种简单的方法可以无限地读取数据集?
reinitializable Iterator 将在同一数据集上重新初始化,因此此代码将一遍又一遍地读取同一数据集:
sample = tf.data.Iterator.from_structure(sample_set.output_types,
sample_set.output_shapes).get_next()
sample_it.make_initializer(sample_set) # create initialize op
with tf.Session(config=config) as sess:
sess.run(sample_set_init_op) # initialize in the beginning
while True:
try:
my_sample = sess.run(sample)
except tf.errors.OutOfRangeError:
sess.run(sample_set_init_op) # re-initialize on same dataset
BUF_SIZE = 100 # choose it depending on your data
sample_set = tf.data.Dataset.from_generator( data_generator(filename),
(tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]),
tf.TensorShape([256,256,1]))
).repeat().shuffle(BUF_SIZE)
如果您不向其传递显式 count
,Dataset.repeat()
转换将无休止地重复数据集:
sample_set = tf.data.Dataset.from_generator(
data_generator(filename), (tf.uint8, tf.uint8),
(tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1])))
# Repeats `sample_set` endlessly.
sample_set = sample_set.repeat()
sample = sample_set.make_one_shot_iterator().get_next()
我正在将我的旧数据层(使用队列)切换到 "new" 并推荐数据集 API。我是第一次使用它,所以我提供了代码示例以防我遇到根本性错误。
我从一个生成器创建我的数据集(它将读取一个文件,并提供 n 个样本)。这是一个小数据集 n_iterations >> n_samples,所以我只是想一遍又一遍地阅读这个数据集,最好是打乱顺序。
sample_set = tf.data.Dataset.from_generator( data_generator(filename),
(tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1]))
)
使用数据生成器:
class data_generator:
def __init__(self, filename):
self.filename= filename
def __call__(self):
with filename.open() as f:
for idx in f: yield img[idx], label[idx]
要实际使用数据,我知道我需要定义一个 Iterator
sample = sample_set.make_one_shot_iterator().get_next()
然后我们设置读取数据
while True:
try: my_sample = sess.run(sample)
except tf.errors.OutOfRangeError: break # this happens after dset is read once
但所有可用的迭代器似乎都是 "finite",因为它们只读取数据集一次。
是否有一种简单的方法可以无限地读取数据集?
reinitializable Iterator 将在同一数据集上重新初始化,因此此代码将一遍又一遍地读取同一数据集:
sample = tf.data.Iterator.from_structure(sample_set.output_types,
sample_set.output_shapes).get_next()
sample_it.make_initializer(sample_set) # create initialize op
with tf.Session(config=config) as sess:
sess.run(sample_set_init_op) # initialize in the beginning
while True:
try:
my_sample = sess.run(sample)
except tf.errors.OutOfRangeError:
sess.run(sample_set_init_op) # re-initialize on same dataset
BUF_SIZE = 100 # choose it depending on your data
sample_set = tf.data.Dataset.from_generator( data_generator(filename),
(tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]),
tf.TensorShape([256,256,1]))
).repeat().shuffle(BUF_SIZE)
如果您不向其传递显式 count
,Dataset.repeat()
转换将无休止地重复数据集:
sample_set = tf.data.Dataset.from_generator(
data_generator(filename), (tf.uint8, tf.uint8),
(tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1])))
# Repeats `sample_set` endlessly.
sample_set = sample_set.repeat()
sample = sample_set.make_one_shot_iterator().get_next()