Tensorflow Dataset API 是否完全摆脱了 feed_dict 论点?

Does Tensorflow Dataset API totally get rid of feed_dict argument?

我开始使用数据集 API 来替换 feed_dict 系统。

但是,在创建数据集管道后,如何在不使用 feed_dict 的情况下将数据集的数据提供给模型?

首先,我创建了一个一次性迭代器。但在这种情况下,您需要使用 feed_dict 将来自迭代器的数据提供给模型。

其次,我尝试直接从 tf.placeholder 创建我的数据集,然后使用 initializable_iterator。但是又一次,我不明白如何摆脱feed_dict。另外,我不明白这种基于占位符的数据集的目的是什么。

我的基本模型:

x = tf.placeholder(tf.float32, [None, 2])
dense = tf.layers.dense(x, 1)
init_dense = tf.global_variables_initializer()

我的数据:

np_data = np.random.sample((100,2))

方法一:

dataset = tf.data.Dataset.from_tensor_slices(np_data)
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()

with tf.Session() as sess:
  sess.run(init_glob)

  for i in range(100):
    value = sess.run(next_value)
    # Cannot get rid of feed_dict
    result = sess.run(dense, feed_dict({x: value})

方法二:

dataset = tf.data.Dataset.from_tensor_slices(x)
iterator = dataset.make_initializable_iterator()
next_value = iterator.get_next()

with tf.Session() as sess:
  sess.run(init_glob)
  sess.run(iterator.initializer, feed_dict={x: np_data})

  for i in range(100):
    value = sess.run(next_value)
    # Cannot get rid of feed_dict
    result = sess.run(dense, feed_dict({x: value})

https://www.tensorflow.org/guide/performance/overview#input_pipeline

那么,我怎样才能 "Avoid using feed_dict for all but trivial examples" 呢? 我想我没有理解数据集的概念API

是的,如果使用数据集api,我们不需要使用feed_dict

相反,我们每次都可以将致密层应用到 next_value

像这样:

def model(x):
  dense = tf.layers.dense(x, 1)
  return dense

result_for_this_iteration = model(next_value)

所以你的完整玩具示例可能看起来像这样:

def model(x):
  dense = tf.layers.dense(x, 10)
  return dense

dataset = tf.data.Dataset.from_tensor_slices(np.random.sample((100, 2, 2)))
iterator = dataset.make_one_shot_iterator()
next_value = iterator.get_next()

result_for_this_iteration = model(next_value)


with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  while(True):
    try:
      result = sess.run(result_for_this_iteration)
      print (result)
    except OutOfRangeError:
      print ("no more data")

当然,其他配置选项比比皆是。我们可以 repeat() 这样我们就不会到达数据的末尾而是循环遍历它。我们可以 batch(n) 分成大小 n 的批次。我们可以 map(pre_process) 对每个元素应用 pre_process 函数,等等