我应该直接 return 数据集还是应该使用 one_shot 迭代器?
Should I return dataset directly or should i use one_shot iterator instead?
我正在使用数据集 API 构建数据管道,但是当我在输入函数中训练多个 GPU 和 return dataset.make_one_shot_iterator().get_next()
时,我得到
ValueError: dataset_fn() must return a tf.data.Dataset when using a tf.distribute.Strategy
我可以直接按照错误消息和 return 数据集进行操作,但我不明白 iterator().get_next()
的目的以及它如何在单 GPU 和多 GPU 上进行训练。
...
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size = batch_size)
dataset = dataset.cache()
dataset = dataset.prefetch(buffer_size=None)
return dataset.make_one_shot_iterator().get_next()
return _input_fn
当使用 tf.data
分配策略时(可以与 Keras 和 tf.Estimator
s 一起使用),你的输入 fn 应该 return a tf.data.Dataset
:
def input_fn():
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size = batch_size)
dataset = dataset.cache()
dataset = dataset.prefetch(buffer_size=None)
return dataset
...use input_fn...
参见 documentation 分发策略。
dataset.make_one_shot_iterator()
在分发策略/更高级别的库之外很有用,例如,如果您正在使用较低级别的库,或者调试/测试数据集。例如,您可以像这样迭代数据集的所有元素:
dataset = ...
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with tf.Session() as sess:
while True:
print(sess.run(get_next))
except tf.errors.OutOfRangeError:
break
我正在使用数据集 API 构建数据管道,但是当我在输入函数中训练多个 GPU 和 return dataset.make_one_shot_iterator().get_next()
时,我得到
ValueError: dataset_fn() must return a tf.data.Dataset when using a tf.distribute.Strategy
我可以直接按照错误消息和 return 数据集进行操作,但我不明白 iterator().get_next()
的目的以及它如何在单 GPU 和多 GPU 上进行训练。
...
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size = batch_size)
dataset = dataset.cache()
dataset = dataset.prefetch(buffer_size=None)
return dataset.make_one_shot_iterator().get_next()
return _input_fn
当使用 tf.data
分配策略时(可以与 Keras 和 tf.Estimator
s 一起使用),你的输入 fn 应该 return a tf.data.Dataset
:
def input_fn():
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size = batch_size)
dataset = dataset.cache()
dataset = dataset.prefetch(buffer_size=None)
return dataset
...use input_fn...
参见 documentation 分发策略。
dataset.make_one_shot_iterator()
在分发策略/更高级别的库之外很有用,例如,如果您正在使用较低级别的库,或者调试/测试数据集。例如,您可以像这样迭代数据集的所有元素:
dataset = ...
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with tf.Session() as sess:
while True:
print(sess.run(get_next))
except tf.errors.OutOfRangeError:
break