将数据显式提供给 'tf.data.Dataset' 会影响性能吗

Will it impair the performance to explicitly feed data to 'tf.data.Dataset'

我正在实施 RL 算法并使用 tf.data.Dataset(with prefetch) 将数据馈送到神经网络。但是,为了与环境交互,我必须通过 feed_dict 明确地提供数据以采取行动。我想知道将 feed_dictDataset 一起使用是否会影响速度。

这是我的代码的简化版本

# code related to Dataset
ds = tf.data.Dataset.from_generator(buffer, sample_types, sample_shapes)
ds = ds.prefetch(5)
iterator = ds.make_one_shot_iterator()
samples = iterator.get_next(name='samples')
# pass samples to network
# network training, no feed_dict is needed because of Dataset
sess.run([self.opt_op])
# run the actor network to choose an action at the current state.
# manually feed the current state to samples
# will this impair the performance?
action = sess.run(self.action, feed_dict={samples['state']: state})

混合数据集和 feed_dict 没有错。如果您提供给 feed_dict 的状态很大,则可能会导致 GPU 未充分利用,具体取决于数据的大小。但它绝不会与正在使用的数据集有关。

数据集 API 存在的原因之一是避免模型饥饿并提高训练期间的 GPU 利用率。饥饿可能是由于数据从一个位置复制到另一个位置而发生的:磁盘到内存,内存到 GPU 内存,随便你怎么说。 Dataset 尝试尽早开始执行庞大的 IO 操作,以避免在处理下一批时使模型饿死。所以,基本上,Datasets 试图减少批次之间的时间。

在您的情况下,使用 feed_dict 可能不会降低任何性能。无论如何,您似乎通过环境交互中断了执行(因此,可能未充分利用 GPU)。

如果您想确定,请在使用 feed_dict 提供实际状态时对您的性能进行计时,而不是用常量张量替换状态使用并比较速度。