LookupError in TensorFlow with tf.cond()

Working environment

Problem description

I use tf.cond() to switch between the training and validation datasets during processing. The following snippet shows how I do it:

with tf.variable_scope(tf.get_variable_scope()) as vscope:
    for i in range(4):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('GPU-Tower-%d' % i) as scope:
                worktype = tf.get_variable("wt", [], initializer=tf.zeros_initializer())
                worktype = tf.assign(worktype, 1)
                workcondition = tf.equal(worktype, 1)
                elem = tf.cond(workcondition,
                               lambda: train_iterator.get_next(),
                               lambda: val_iterator.get_next())
                net = vgg16cnn2(elem[0], numclasses=256)
                img = elem[0]
                centropy = tf.reduce_mean(
                    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=elem[1], logits=net))
                reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope)
                regloss = 0.05 * tf.reduce_sum(reg_losses)
                total_loss = centropy + regloss
                t1 = tf.summary.scalar("Training Batch Loss", total_loss)
                tf.get_variable_scope().reuse_variables()
                predictions = tf.cast(tf.argmax(tf.nn.softmax(net), 1), tf.int32)
                correct_predictions = tf.cast(tf.equal(predictions, elem[1]), tf.float32)
                batch_accuracy = tf.reduce_mean(correct_predictions)
                t2 = tf.summary.scalar("Training Batch Accuracy", batch_accuracy)
                correct_detection.append(correct_predictions)
                grads = optim.compute_gradients(total_loss)

So basically, depending on the value of worktype, a minibatch is drawn from either the training or the validation set.

When I run this code, I get the following LookupError:

LookupError: No gradient defined for operation 'GPU-Tower-0/cond/IteratorGetNext_1' (op type: IteratorGetNext)

Why does TensorFlow think that IteratorGetNext_1 requires a gradient? How can I fix this?

The variable worktype is marked as trainable. By default, Optimizer.compute_gradients(...) computes gradients for all trainable variables, so it tries to differentiate through the tf.cond(...) branch that contains the IteratorGetNext op.

There are two ways to fix this:

  1. Set trainable=False in tf.get_variable(...).
  2. Explicitly specify the variables for which gradients should be computed, using the var_list argument of Optimizer.compute_gradients(...).
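Both fixes can be sketched in a minimal, self-contained example. This uses the TF1-style API through tf.compat.v1 so it also runs on modern TensorFlow installs; the variable x and the toy loss are illustrative stand-ins, not from the question:

```python
import tensorflow as tf

# TF1-style graph mode (the question uses the 1.x API).
tf.compat.v1.disable_eager_execution()

# Fix 1: mark the switch variable as non-trainable so the
# optimizer never tries to differentiate through tf.cond().
worktype = tf.compat.v1.get_variable(
    "wt", [], trainable=False,
    initializer=tf.compat.v1.zeros_initializer())

# A stand-in model variable and loss.
x = tf.compat.v1.get_variable(
    "x", [], initializer=tf.compat.v1.ones_initializer())
loss = x * x

optim = tf.compat.v1.train.GradientDescentOptimizer(0.1)

# Fix 2: restrict gradient computation to the model
# variables via the var_list argument.
grads = optim.compute_gradients(loss, var_list=[x])

# Only x appears among the gradient targets; worktype is excluded.
grad_vars = [v.name for _, v in grads]
```

Either fix alone is sufficient: the LookupError disappears as soon as compute_gradients no longer considers the switch variable a gradient target.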