等同于 tf.identity 对操作节点具有控制依赖性

Equivalent of tf.identity with control dependency for an operation node

我正在编写一个包装器 class,它采用具有特殊成员的通用图 "train_op" 来管理我的模型的训练、保存和管理。

我想像这样清楚地跟踪训练步骤的生命周期数:

with tf.control_dependencies([ step_add_one ]):
    self.train_op=tf.identity(self.training_graph.train_op )

raise TypeError('Expected binary or unicode string, got %r' e, is_training=True, inputs=None)

我认为这里的问题是 train_op 是 tf.Optimizer.minimize() 的 return,所以它本身不是张量,而是一个运算。

一个明显的解决方法是在 training_graph.loss 上调用 tf.identity,但我失去了一点抽象,因为我必须在外部处理学习率等。而且,我总觉得自己少了什么。

我怎样才能最好地解决这个问题?

您可以使用 tf.group(),它将与操作 张量一起使用。

例如:

x = tf.Variable(1.)
loss = tf.square(x)
optimizer = tf.train.GradientDescentOptimizer(0.1)
train_op = optimizer.minimize(loss)

step = tf.Variable(0)
step_add_one = step.assign_add(1)

with tf.control_dependencies([step_add_one]):
    train_op_2 = tf.group(train_op)

现在当你运行train_op_2时,step的值会递增。


但是,最好的方法(如果您可以修改创建图表的图表)是将参数 global_step 添加到 minimize 函数:

train_op = optimizer.minimize(loss, global_step=step)