Tensorflow Train on incomplete batch

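I'm trying to do batch training in TensorFlow. It partly works, because I can go through the first epoch in batches. My code currently has 2 problems.
1. After the first epoch finishes, the second epoch immediately goes into except tf.errors.OutOfRangeError, and the next epoch does not restart the batches from the top. How can I run another epoch over the batches?
2. I printed batchnr and noticed that for the last batch of the epoch, print(batchnr) is printed but print("End", batchnr) is not; it goes straight to the except and that batch is never trained on. I guess this is because the number of rows left in the queue is smaller than the batch size. How can I train on the last batch?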

My train method and pipeline method:

def input_pipeline(file, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer([file], num_epochs=num_epochs, shuffle=True)
  example, label = read_from_csv(filename_queue)
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * 2
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch

def train():
    examples, labels = input_pipeline(training_data_file, batch_size, 1)
    saver = tf.train.Saver()
    prediction = neural_network_model(p_inputdata)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=p_known_labels))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

    init = tf.group(tf.initialize_all_variables(),
                    tf.initialize_local_variables())
    with tf.Session() as sess:
        sess.run(init) # initialize all variables **in** the session

        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(p_known_labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

        latest_cost_of_batch = None
        for e in range(epochs):
            epoch = e + 1
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                batchnr = 1
                while not coord.should_stop():
                    print(batchnr)
                    batch_data, batch_labels = sess.run([examples, labels])
                    batch_labels_output = get_output_values(batch_labels)
                    print("End", batchnr)
                    batchnr += 1

                    _, latest_cost_of_batch = sess.run([optimizer,cost], feed_dict={
                        p_inputdata: batch_data,
                        p_known_labels: batch_labels_output
                    })

            except tf.errors.OutOfRangeError:
                print('Done training, epoch reached')
                if (epoch) % print_each_x_number_of_epochs == 0 or epoch == 0:
                    print('Epoch', epoch, 'completed out of', epochs, "---", 'Loss', latest_cost_of_batch)
                if epoch % save_each_x_number_of_epochs == 0:
                    saver.save(sess, checkpoint_label)
            finally:
                coord.request_stop()
        coord.join(threads)

        print("Trained for ", epochs,"epochs. Saving variables...")
        saver.save(sess, checkpoint_label)
        print("Variables saved. Training finished.")
    end = time.time()
    seconds = end - start
    print("Total runtime:", str(datetime.timedelta(seconds=seconds)))

Debug console:

Start training
1
End 1
2
End 2
....
213
End 213
214
Done training, epoch reached
Epoch 1 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 2 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 3 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 4 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 5 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 6 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 7 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 8 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 9 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 10 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 11 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 12 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 13 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 14 completed out of 15 --- Loss 4.43414
1
Done training, epoch reached
Epoch 15 completed out of 15 --- Loss 4.43414
Trained for  15 epochs. Saving variables...
Variables saved. Training finished.
Accuracy 0.935310311615 % after 15 epochs of training
Total runtime: 0:00:21.395917

EDIT
I changed my code based on Nicolas's answer (I now pass multiple epochs to string_input_producer). This is the code I train with now:

def train():
    """Trains the neural network  
    """
    examples, labels = input_pipeline(training_data_file, batch_size, epochs)
    start = time.time()
    saver = tf.train.Saver()
    prediction = neural_network_model(p_inputdata)
    first_no_loss = True
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels=p_known_labels))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

    init = tf.group(tf.initialize_all_variables(),
                    tf.initialize_local_variables())
    with tf.Session() as sess:
        sess.run(init) # initialize all variables **in** the session
        correct = tf.equal(tf.argmax(prediction, 1), tf.argmax(p_known_labels, 1))
        accuracy = tf.reduce_mean(tf.cast(correct, 'float'))

        print("Start training")
        latest_cost_of_batch = None

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        epoch_op = "input_producer/limit_epochs/epochs:0"
        try:
            batchnr = 1
            epochs_var = 0
            while not coord.should_stop():
                if (batchnr) % print_each_x_number_of_batches == 0:
                    print('Batch', batchnr, 'completed of epoch', epochs_var, "---", 'Loss', latest_cost_of_batch)

                if  batchnr > 3194:
                    print("GETTING BATCH", batchnr)
                epochs_var, batch_data, batch_labels = sess.run([epoch_op, examples, labels])
                batch_labels_output = get_output_values(batch_labels)
                if  batchnr > 3194:
                    print("GOT BATCH", batchnr)
                batchnr += 1
                _, latest_cost_of_batch = sess.run([optimizer,cost], feed_dict={
                    p_inputdata: batch_data,
                    p_known_labels: batch_labels_output
                })

        except tf.errors.OutOfRangeError:
            print('Done training, epoch reached')
        finally:
            coord.request_stop()

        coord.join(threads)

        print("Trained for ", epochs,"epochs. Saving variables...")
        saver.save(sess, checkpoint_label)
        print("Variables saved. Training finished.")
        labels, values, output = get_training_or_testdata(training_data_file)
        print('Accuracy', accuracy.eval(feed_dict={p_inputdata: values, p_known_labels: output}) * 100, '% after', epochs, 'epochs of training')
    end = time.time()
    seconds = end - start
    print("Total runtime:", str(datetime.timedelta(seconds=seconds)))

My output looks like this:

Start training
Batch 100 completed of epoch 15 --- Loss 4.79351
Batch 200 completed of epoch 15 --- Loss 4.57468
Batch 300 completed of epoch 15 --- Loss 4.51134
Batch 400 completed of epoch 15 --- Loss 4.65865
Batch 500 completed of epoch 15 --- Loss 4.55456
Batch 600 completed of epoch 15 --- Loss 4.63549
Batch 700 completed of epoch 15 --- Loss 4.53037
Batch 800 completed of epoch 15 --- Loss 4.49263
Batch 900 completed of epoch 15 --- Loss 4.37
Batch 1000 completed of epoch 15 --- Loss 4.42719
Batch 1100 completed of epoch 15 --- Loss 4.4518
Batch 1200 completed of epoch 15 --- Loss 4.41053
Batch 1300 completed of epoch 15 --- Loss 4.43508
Batch 1400 completed of epoch 15 --- Loss 4.32173
Batch 1500 completed of epoch 15 --- Loss 4.36624
Batch 1600 completed of epoch 15 --- Loss 4.44027
Batch 1700 completed of epoch 15 --- Loss 4.37201
Batch 1800 completed of epoch 15 --- Loss 4.24956
Batch 1900 completed of epoch 15 --- Loss 4.40256
Batch 2000 completed of epoch 15 --- Loss 4.18391
Batch 2100 completed of epoch 15 --- Loss 4.30156
Batch 2200 completed of epoch 15 --- Loss 4.38423
Batch 2300 completed of epoch 15 --- Loss 4.23823
Batch 2400 completed of epoch 15 --- Loss 4.17783
Batch 2500 completed of epoch 15 --- Loss 4.31024
Batch 2600 completed of epoch 15 --- Loss 4.26312
Batch 2700 completed of epoch 15 --- Loss 4.26143
Batch 2800 completed of epoch 15 --- Loss 4.16691
Batch 2900 completed of epoch 15 --- Loss 4.48624
Batch 3000 completed of epoch 15 --- Loss 4.1347
Batch 3100 completed of epoch 15 --- Loss 4.20801
GETTING BATCH 3195
GOT BATCH 3195
GETTING BATCH 3196
GOT BATCH 3196
GETTING BATCH 3197
Done training, epoch reached
Trained for  15 epochs. Saving variables...
Variables saved. Training finished.
Accuracy 2.69019026309 % after 15 epochs of training
Total runtime: 0:03:07.577149

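What I notice is that, first, the last batch still doesn't get trained (GOT BATCH 3197 is never printed), and second, the way I get the current epoch is wrong: it is always 15. This answer explains why the way I'm doing it now is not the right approach, but it doesn't explain the right way to get the current epoch. Any clues?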


Edit: you may want to take a look at this, which gives an example using the new API.

Here is an explanation of what you are getting.

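  • The first time the for e in range(epochs) loop runs, it dequeues everything from the data queue (until the data queue raises tf.errors.OutOfRangeError).

    This error is raised when there are no more filenames in the filename queue. That happens after the file has been read only once, because you called examples, labels = input_pipeline(training_data_file, batch_size, 1).

    For example, if you had called examples, labels = input_pipeline(training_data_file, batch_size, 3), you would have gone through the file 3 times before moving on to e=1.

  • Then, when you move on to e>0, the filename queue is still kept in memory. You have already dequeued all the filenames from it, and since there are no more enqueue operations, it raises tf.errors.OutOfRangeError immediately.

    Look at the string_input_producer documentation: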

    Note: if num_epochs is not None, this function creates local counter epochs. Use local_variables_initializer() to initialize local variables.
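To make the two points above concrete, here is a pure-Python analogy (not TensorFlow code, and not how TensorFlow implements the queue): a generator stands in for the filename queue created with num_epochs, and StopIteration plays the role of tf.errors.OutOfRangeError. The file name train.csv is hypothetical. Like the filename queue, an exhausted generator does not refill itself when the outer loop starts its next iteration.

```python
# Analogy only: a generator plays the filename queue, StopIteration plays
# tf.errors.OutOfRangeError.
def filename_producer(filenames, num_epochs):
    """Yield every filename num_epochs times, then stop for good."""
    for _ in range(num_epochs):
        for name in filenames:
            yield name


def consume_per_outer_epoch(producer, outer_epochs):
    """Mimic the question's `for e in range(epochs)` loop over ONE producer.

    Returns how many items each outer iteration managed to dequeue.
    """
    counts = []
    for _ in range(outer_epochs):
        n = 0
        for _ in producer:  # stops at once if the producer is already exhausted
            n += 1
        counts.append(n)
    return counts


if __name__ == "__main__":
    # num_epochs=1: only the first outer iteration gets any data.
    print(consume_per_outer_epoch(filename_producer(["train.csv"], 1), 3))  # [1, 0, 0]
    # num_epochs=3: all 3 passes happen before the outer loop reaches e=1.
    print(consume_per_outer_epoch(filename_producer(["train.csv"], 3), 3))  # [3, 0, 0]
```

The same shape explains the debug output in the question: every outer iteration after the first prints a single batch number and then falls straight into the except branch.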

What can you do?

  1. Move the session context manager inside the for e in range(epochs) loop:

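    init_queue = tf.variables_initializer(
        tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope='input_producer'))
    for e in range(EPOCHS):
        with tf.Session() as sess:
            # Variable values do not carry over between sessions: initialize
            # them in the first session, restore the last checkpoint afterwards.
            if e == 0:
                sess.run(init)
            else:
                saver.restore(sess, checkpoint_label)
            sess.run(init_queue)  # initialize the local variables in the input_producer scope
            epoch = e + 1
            # ... start the queue runners, train until OutOfRangeError,
            # then saver.save(sess, checkpoint_label) before the session closes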
    

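    This means you re-initialize all the local variables in the input_producer scope, so you need to be careful about what they are. You can also save your model and reload it at each step, or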

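  2. Rely on the num_epochs argument to run the right number of epochs and remove your for e in range(EPOCHS) loop. Instead of printing information at the end of each epoch, you can print it every 100 or 1000 training steps (my favorite solution). If you really want to print information at the end of each epoch, you can try to access the hidden epochs variable, evaluate its value, and print information whenever 'epochs' changes (I don't recommend this option).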

For example:

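    epoch_op = "input_producer/limit_epochs/epochs:0"  # hidden local epoch counter
    last_epoch = None
    batchnr = 0
    while not coord.should_stop():
        epochs_var, batch_data, batch_labels = sess.run([epoch_op, examples, labels])
        if epochs_var != last_epoch:
            # the hidden 'epochs' counter changed: print your epoch info here
            print("epochs counter changed:", epochs_var)
            last_epoch = epochs_var
        # ... run the training step on batch_data / batch_labels here
        print("End", batchnr)
        batchnr += 1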
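Independently of how epochs are counted, the last batch is still lost because tf.train.shuffle_batch only returns full batches by default: when the queue is closed with fewer than batch_size examples left, it raises tf.errors.OutOfRangeError and the pending examples are discarded. Passing allow_smaller_final_batch=True to tf.train.shuffle_batch makes it return the remainder as a smaller final batch; the output tensors then have a None batch dimension, so this assumes p_inputdata and p_known_labels also accept a variable batch size. The pure-Python sketch below only models the counting (not the shuffling), under the assumption that all epochs flow through one queue; the numbers are hypothetical.

```python
# Counting model only (not TensorFlow code): with num_epochs set on the
# producer, the batcher sees num_rows * num_epochs examples back to back.
def batch_sizes(num_rows, num_epochs, batch_size, allow_smaller_final_batch=False):
    """Return the sizes of the batches a num_epochs-limited pipeline yields."""
    total = num_rows * num_epochs
    full_batches, remainder = divmod(total, batch_size)
    sizes = [batch_size] * full_batches
    # Without allow_smaller_final_batch the leftover examples are discarded.
    if remainder and allow_smaller_final_batch:
        sizes.append(remainder)
    return sizes


if __name__ == "__main__":
    # Hypothetical numbers: 10 rows, 3 epochs, batches of 4 -> 30 examples.
    print(batch_sizes(10, 3, 4))                                  # [4, 4, 4, 4, 4, 4, 4]
    print(batch_sizes(10, 3, 4, allow_smaller_final_batch=True))  # [4, 4, 4, 4, 4, 4, 4, 2]
```

Because the epochs are concatenated, only one partial batch can appear, at the very end of training, which matches the question's log where batch 3197 is requested but never returned.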

Hope this helps!

Comment on the edited question:

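Looking at the emphasized part of this sentence from the answer you mentioned, it seems to me that you cannot know which epoch a dequeued element belongs to.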

When tf.start_queue_runners() is executed, all the epochs are enqueued together (in multiple stages if capacity is less than number of filenames). The local variable epochs:0 is used by tf.train.string_input_producer to maintain the epoch that is being enqueued. Once epochs:0 reaches num_epochs, it remains constant and no matter how many threads are dequeuing from the queue, it does not change.
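If the number of rows in the training file is known, one rough alternative is to derive the epoch from the batch counter instead of reading epochs:0. This is a minimal sketch, not something from the quoted answer; num_rows is a hypothetical name for the number of examples in training_data_file. Because batches are cut from one continuous stream, and shuffle_batch additionally mixes examples from neighbouring epochs inside its buffer, the value is only approximate near epoch boundaries.

```python
def epoch_of_batch(batchnr, batch_size, num_rows):
    """Approximate 1-based epoch of the 1-based batch `batchnr`.

    Uses the position of the batch's first row in the stream of
    num_rows * num_epochs rows; a batch near an epoch boundary can still
    contain rows from two epochs.
    """
    if batchnr < 1 or batch_size < 1 or num_rows < 1:
        raise ValueError("batchnr, batch_size and num_rows must be positive")
    first_row = (batchnr - 1) * batch_size
    return first_row // num_rows + 1


if __name__ == "__main__":
    # Hypothetical numbers: 1000 rows in the CSV file, batches of 64.
    for batchnr in (1, 16, 17, 32, 33):
        print(batchnr, epoch_of_batch(batchnr, batch_size=64, num_rows=1000))
```

In the training loop you would compare epoch_of_batch(batchnr, batch_size, num_rows) with its previous value and print the end-of-epoch information when it changes.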