Tensorflow Dataset API 是否完全摆脱了 feed_dict 论点?
Does Tensorflow Dataset API totally get rid of feed_dict argument?
我开始使用数据集 API 来替换 feed_dict 系统。
但是,在创建数据集管道后,如何在不使用 feed_dict 的情况下将数据集的数据提供给模型?
首先,我创建了一个一次性迭代器。但在这种情况下,您需要使用 feed_dict 将来自迭代器的数据提供给模型。
其次,我尝试直接从 tf.placeholder 创建我的数据集,然后使用 initializable_iterator。但是又一次,我不明白如何摆脱feed_dict。另外,我不明白这种基于占位符的数据集的目的是什么。
我的基本模型:
x = tf.placeholder(tf.float32, [None, 2])
dense = tf.layers.dense(x, 1)
init_dense = tf.global_variables_initializer()
我的数据:
np_data = np.random.sample((100,2))
方法一:
dataset = tf.data.Dataset.from_tensor_slices(np_data)
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()
with tf.Session() as sess:
sess.run(init_glob)
for i in range(100):
value = sess.run(next_value)
# Cannot get rid of feed_dict
result = sess.run(dense, feed_dict({x: value})
方法二:
dataset = tf.data.Dataset.from_tensor_slices(x)
iterator = dataset.make_initializable_iterator()
next_value = iterator.get_next()
with tf.Session() as sess:
sess.run(init_glob)
sess.run(iterator.initializer, feed_dict={x: np_data})
for i in range(100):
value = sess.run(next_value)
# Cannot get rid of feed_dict
result = sess.run(dense, feed_dict({x: value})
https://www.tensorflow.org/guide/performance/overview#input_pipeline
那么,我怎样才能 "Avoid using feed_dict for all but trivial examples" 呢?
我想我没有理解数据集的概念API
是的,如果使用数据集api,我们不需要使用feed_dict
。
相反,我们每次都可以将致密层应用到 next_value
。
像这样:
def model(x):
dense = tf.layers.dense(x, 1)
return dense
result_for_this_iteration = model(next_value)
所以你的完整玩具示例可能看起来像这样:
def model(x):
dense = tf.layers.dense(x, 10)
return dense
dataset = tf.data.Dataset.from_tensor_slices(np.random.sample((100, 2, 2)))
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()
result_for_this_iteration = model(next_value)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
while(True):
try:
result = sess.run(result_for_this_iteration)
print (result)
except OutOfRangeError:
print ("no more data")
当然,其他配置选项比比皆是。我们可以 repeat()
这样我们就不会到达数据的末尾而是循环遍历它。我们可以 batch(n)
分成大小 n
的批次。我们可以 map(pre_process)
对每个元素应用 pre_process
函数,等等
我开始使用数据集 API 来替换 feed_dict 系统。
但是,在创建数据集管道后,如何在不使用 feed_dict 的情况下将数据集的数据提供给模型?
首先,我创建了一个一次性迭代器。但在这种情况下,您需要使用 feed_dict 将来自迭代器的数据提供给模型。
其次,我尝试直接从 tf.placeholder 创建我的数据集,然后使用 initializable_iterator。但是又一次,我不明白如何摆脱feed_dict。另外,我不明白这种基于占位符的数据集的目的是什么。
我的基本模型:
x = tf.placeholder(tf.float32, [None, 2])
dense = tf.layers.dense(x, 1)
init_dense = tf.global_variables_initializer()
我的数据:
np_data = np.random.sample((100,2))
方法一:
dataset = tf.data.Dataset.from_tensor_slices(np_data)
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()
with tf.Session() as sess:
sess.run(init_glob)
for i in range(100):
value = sess.run(next_value)
# Cannot get rid of feed_dict
result = sess.run(dense, feed_dict({x: value})
方法二:
dataset = tf.data.Dataset.from_tensor_slices(x)
iterator = dataset.make_initializable_iterator()
next_value = iterator.get_next()
with tf.Session() as sess:
sess.run(init_glob)
sess.run(iterator.initializer, feed_dict={x: np_data})
for i in range(100):
value = sess.run(next_value)
# Cannot get rid of feed_dict
result = sess.run(dense, feed_dict({x: value})
https://www.tensorflow.org/guide/performance/overview#input_pipeline
那么,我怎样才能 "Avoid using feed_dict for all but trivial examples" 呢? 我想我没有理解数据集的概念API
是的,如果使用数据集api,我们不需要使用feed_dict
。
相反,我们每次都可以将致密层应用到 next_value
。
像这样:
def model(x):
dense = tf.layers.dense(x, 1)
return dense
result_for_this_iteration = model(next_value)
所以你的完整玩具示例可能看起来像这样:
def model(x):
dense = tf.layers.dense(x, 10)
return dense
dataset = tf.data.Dataset.from_tensor_slices(np.random.sample((100, 2, 2)))
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()
result_for_this_iteration = model(next_value)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
while(True):
try:
result = sess.run(result_for_this_iteration)
print (result)
except OutOfRangeError:
print ("no more data")
当然,其他配置选项比比皆是。我们可以 repeat()
这样我们就不会到达数据的末尾而是循环遍历它。我们可以 batch(n)
分成大小 n
的批次。我们可以 map(pre_process)
对每个元素应用 pre_process
函数,等等