将 tf.placeholder 和 feed_dict 替换为 tf.data API

Replacing tf.placeholder and feed_dict with tf.data API

我有一个现有的 TensorFlow 模型,它使用 tf.placeholder 作为模型输入和 tf.Session() 的 feed_dict 参数。运行 馈入数据。以前整个数据集都是通过这种方式读入内存并传递的。

我想使用更大的数据集并利用 tf.data API 的性能改进。我已经从中定义了一个 tf.data.TextLineDataset 和一次性迭代器,但我很难弄清楚如何将数据导入模型以对其进行训练。

起初我试图将 feed_dict 定义为从占位符到 iterator.get_next() 的字典,但这给了我一个错误,指出提要的值不能是 tf.Tensor 对象。更多的挖掘让我明白这是因为 iterator.get_next() 返回的对象已经是图表的一部分,这与你输入 feed_dict 的对象不同——而且我不应该尝试出于性能原因,无论如何都使用 feed_dict 。

所以现在我已经摆脱了输入 tf.placeholder 并将其替换为定义我的模型的 class 构造函数的参数;在我的训练代码中构建模型时,我将 iterator.get_next() 的输出传递给该参数。这看起来有点笨拙,因为它打破了模型定义和 datasets/training 过程之间的分离。我现在收到一条错误消息,表示代表(我相信)我模型输入的张量必须与来自 iterator.get_next().

的张量来自同一张图

我使用这种方法是否正确,只是在设置图形和会话的方式上做错了什么,或者类似的事情? (数据集和模型都是在会话之外初始化的,错误发生在我尝试创建一个之前。)

或者我是否完全不符合这一点,需要做一些不同的事情,比如使用 Estimator API 并在输入函数中定义所有内容?

下面是一些演示最小示例的代码:

import tensorflow as tf
import numpy as np

class Network:
    def __init__(self, x_in, input_size):
        self.input_size = input_size
        # self.x_in = tf.placeholder(dtype=tf.float32, shape=(None, self.input_size))  # Original
        self.x_in = x_in
        self.output_size = 3

        tf.reset_default_graph()  # This turned out to be the problem

        self.layer = tf.layers.dense(self.x_in, self.output_size, activation=tf.nn.relu)
        self.loss = tf.reduce_sum(tf.square(self.layer - tf.constant(0, dtype=tf.float32, shape=[self.output_size])))

data_array = np.random.standard_normal([4, 10]).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data_array).batch(2)

model = Network(x_in=dataset.make_one_shot_iterator().get_next(), input_size=dataset.output_shapes[-1])

我也花了点时间才明白过来。你在正确的轨道上。整个数据集定义只是图的一部分。我通常将其创建为与我的模型 class 不同的 class,并将数据集传递到模型 class 中。我指定要在命令行上加载的数据集 class,然后动态加载 class,从而模块化地解耦数据集和图形。

请注意,您可以(并且应该)命名数据集中的所有张量,这确实有助于在您通过所需的各种转换传递数据时使事情变得容易理解。

您可以编写简单的测试用例,从 iterator.get_next() 中提取样本并显示它们,您会得到类似 sess.run(next_element_tensor) 的结果,没有 feed_dict,正如您正确指出的那样。

一旦您了解它,您可能会开始喜欢数据集输入管道。它迫使你很好地模块化你的代码,并迫使它成为一个易于单元测试的结构。

确保您阅读了开发者指南,那里有大量示例:

https://www.tensorflow.org/programmers_guide/datasets

我要注意的另一件事是使用此管道处理训练和测试数据集是多么容易。这很重要,因为您经常在训练数据集上执行数据扩充,而不是在测试数据集上执行,from_string_handle 允许您这样做,并且在上面的指南中有清楚的描述。

我收到的原始代码中模型构造函数中的行 tf.reset_default_graph() 导致了它。删除它修复它。