具有多个输入的 Tensorflow 图没有 tf.placeholder 用于验证

Tensorflow graph with multiple inputs wihtout tf.placeholder for validation

我使用 tf.data API 作为我的模型。现在我将 tf.data 迭代器的输出定义为我的网络的输入。在我摆脱了 feed_dict 方法之后,我的表现有了显着的提高。

现在我想实现一个验证集,每次训练后至少 运行s 一次。有没有办法为 tf.data 实施验证 运行,或者我是否必须设置占位符并手动切换 tf.datasets 并再次使用 feed_dicts?验证测试的推荐方法是什么?

hack-ish 方式 - 节点替换

最简单的方法,虽然绝对不是最漂亮的,但就是使用 tf.data API 创建的节点作为 feed_dict 的输入——这是因为在 Tensorflow 中,您可以通过将其值直接提供给 feed_dict.

来替换计算图中任何节点的值

所以这会是这样的

batch_input = tf_train_data_foo()
validation_input = tf_validation_data_foo()

model = build_model(batch_input)
optimization_step = some_optimization_foo(model)

# Regular train
session.run(optimization_step)

# Validation run
validation_data = session.run(validation_input)
session.run(model, {batch_input: validation_data})

更好的方法——变量重用

如果所有构造都使用 tf.get_variable 而不是创建新变量,并且范围都设置为能够获取现有变量,则您只需调用 build_model 函数两次 - 一次使用训练数据(来自 tf.data)和一次验证数据。您可以在

上查看有关变量重用的更多详细信息