TensorFlow:恢复 RNN 网络后损失猛增

TensorFlow: loss jumps up after restoring RNN net

环境信息

问题

我在恢复我的网络(一个基于字符的 RNN 语言模型)时遇到了问题。下面是一个存在同样问题的简化版本。

第一次运行时,我得到的输出例如如下:

    ...
    step 160: loss = 1.956 (perplexity = 7.069016620211226)
    step 180: loss = 1.837 (perplexity = 6.274748642468816)
    step 200: loss = 1.825 (perplexity = 6.202084762557817)

但在第二次运行、恢复参数之后,我得到的是:

    step 220: loss = 2.346 (perplexity = 10.446611983898903)
    step 240: loss = 2.346 (perplexity = 10.446709120339545)
    ...

所有 TensorFlow 变量似乎都已正确恢复,包括要馈送给 RNN 的状态。数据位置也已恢复(由 'step' 决定)。

我还写过一个类似的 MNIST 识别程序,它工作正常:恢复前后的 loss 是连续的。

是否还有其他参数或状态需要保存和恢复?

    import argparse
    import os
    import tensorflow as tf
    import numpy as np
    import math

    B = 20  # batch size
    H = 200 # size of hidden layer of neurons
    T = 25  # number of time steps to unroll the RNN for
    data_file = 'ptb.train.txt' # any plain text file will do
    checkpoint_dir = "tmp"

    #----------------
    # prepare data
    #----------------
    # Read the whole corpus; `with` guarantees the file handle is closed.
    with open(data_file, 'r') as data_fh:
      data = data_fh.read()
    # The vocabulary MUST be built in a deterministic order. Iteration order of a
    # set of str depends on the per-process hash seed (hash randomization), so
    # list(set(data)) yields a different char -> index mapping on every run.
    # The checkpointed embedding rows and softmax columns are indexed by these
    # ids, so an unstable mapping makes restored weights refer to the wrong
    # characters (which shows up as the loss jumping after a restore).
    chars = sorted(set(data))
    data_size, vocab_size = len(data), len(chars)
    print('data has {0} characters, {1} unique.'.format(data_size, vocab_size))
    char_to_ix = { ch:i for i,ch in enumerate(chars) }
    ix_to_char = { i:ch for i,ch in enumerate(chars) }

    # Encode the corpus as ids and truncate it to a multiple of T.
    input_index_raw = np.array([char_to_ix[ch] for ch in data])
    input_index_raw = input_index_raw[0:len(input_index_raw) // T * T]
    # Targets are the inputs shifted left by one (next-character prediction);
    # the last target wraps around to the first character.
    input_index_raw_shift = np.append(input_index_raw[1:], input_index_raw[0])
    # Pack into rows of T consecutive ids: shape (num_packed_data, T).
    input_all = input_index_raw.reshape([-1, T])
    target_all = input_index_raw_shift.reshape([-1, T])
    num_packed_data = len(input_all)

    #----------------
    # build model
    #----------------
    class Model(object):
      """Character-level RNN language model built as a TF1 graph.

      The recurrent state is both fed in through `state_ph` and mirrored into the
      non-trainable variable `state`, so the last state is written to checkpoints.
      Rows of `embedding` and columns of `softmax_w` are indexed by the ids from
      `char_to_ix`, so restored weights are only meaningful if the same vocabulary
      ordering is reproduced on every run.
      """
      def __init__(self):
        # Token ids, shape (batch, T); batch is left as None here.
        self.input_ph = tf.placeholder(tf.int32, [None, T], name="input_ph")
        self.target_ph = tf.placeholder(tf.int32, [None, T], name="target_ph")
        embedding = tf.get_variable("embedding", [vocab_size, H], initializer=tf.random_normal_initializer(), dtype=tf.float32)
        # input_ph is B x T.
        # input_embedded is B x T x H.
        input_embedded = tf.nn.embedding_lookup(embedding, self.input_ph)

        cell = tf.contrib.rnn.BasicRNNCell(H)

        # State fed from Python for each batch; it is the initial_state of dynamic_rnn below.
        self.state_ph = tf.placeholder(tf.float32, (None, cell.state_size), name="state_ph")

        # Make state variable so that it will be saved by the saver.
        self.state = tf.get_variable("state", (B, cell.state_size), initializer=tf.zeros_initializer(), trainable=False, dtype=tf.float32)

        # Construct initial_state according to whether restoring or not.
        # isRestore=True -> the (restored) `state` variable; False -> all zeros.
        self.isRestore = tf.placeholder(tf.bool, shape=(), name="isRestore")
        zero_state = cell.zero_state(B, dtype=tf.float32)
        self.initial_state = tf.cond(self.isRestore, lambda: self.state, lambda: zero_state)

        # input_embedded : B x T x H
        # output: B x T x H
        # state : B x cell.state_size
        output, state_ = tf.nn.dynamic_rnn(cell, input_embedded, initial_state=self.state_ph)
        # Running final_state copies the RNN's final state into the `state`
        # variable, so the next checkpoint contains it; it also yields that value.
        self.final_state = tf.assign(self.state, state_)

        # reshape to (B * T) x H.
        output_flat = tf.reshape(output, [-1, H])

        # Convert hidden layer's output to vector of logits for each vocabulary.
        softmax_w = tf.get_variable("softmax_w", [H, vocab_size], dtype=tf.float32)
        softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=tf.float32)
        logits = tf.matmul(output_flat, softmax_w) + softmax_b

        # cross_entropy is a vector of length B * T
        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.reshape(self.target_ph, [-1]), logits=logits)
        # Mean per-token loss; this is what gets reported.
        self.loss = tf.reduce_mean(cross_entropy)

        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        # Non-trainable step counter, incremented by training_op and saved with
        # the checkpoint so a restored run resumes its step count.
        self.global_step = tf.get_variable("global_step", (), initializer=tf.zeros_initializer(), trainable=False, dtype=tf.int32)
        # NOTE(review): this minimizes the unreduced per-token vector rather than
        # self.loss. TF differentiates a non-scalar target as its sum, so the
        # effective step is scaled by B*T relative to the reported mean loss.
        # Confirm this is intended.
        self.training_op = optimizer.minimize(cross_entropy, global_step=self.global_step)

      def train_batch(self, sess, input_batch, target_batch, initial_state):
        """Run one training step on a batch and return the RNN's final state.

        Args:
          sess: the active tf.Session.
          input_batch: input ids, shape (B, T).
          target_batch: next-character target ids, shape (B, T).
          initial_state: RNN state to start from, fed into state_ph.

        Returns:
          (final_state, loss): the final RNN state (also stored in self.state),
          to be fed as initial_state of the next batch, and the mean
          cross-entropy for this batch.
        """
        final_state_, _, final_loss = sess.run([self.final_state, self.training_op, self.loss], feed_dict={self.input_ph: input_batch, self.target_ph: target_batch, self.state_ph: initial_state})
        return final_state_, final_loss

    # main
    # Train until the next multiple of 200 steps, save a checkpoint, then stop.
    # If a checkpoint exists in checkpoint_dir, resume from it (weights,
    # global_step and the saved RNN state).
    with tf.Session() as sess:
      if not tf.gfile.Exists(checkpoint_dir):
        tf.gfile.MakeDirs(checkpoint_dir)

      # Distance (in packed rows) between the B parallel streams in a batch.
      batch_stride = num_packed_data // B

      # make model
      model = Model()
      saver = tf.train.Saver()

      # always initialize
      init = tf.global_variables_initializer()
      init.run()

      # restore if necessary
      # saver.restore overwrites the freshly initialized values with the checkpoint's.
      isRestore = False
      ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
      if ckpt:
        isRestore = True
        last_model = ckpt.model_checkpoint_path
        print("Loading " + last_model)
        saver.restore(sess, last_model)

      # set initial step
      # global_step is 0 on a fresh start, or the checkpoint's value after a
      # restore, so training continues at the step after the last saved one.
      step = tf.train.global_step(sess, model.global_step) + 1
      print("start step = {0}".format(step))

      # fetch initial state
      # Zero state on a fresh start; on restore, the `state` variable that the
      # last train_batch before saving assigned.
      state =  sess.run(model.initial_state, feed_dict={model.isRestore: isRestore})
      print("Initial state: {0}".format(state))

      while True:
        # prepare batch data
        # Row x of the batch reads stream x, which starts at x * batch_stride
        # and advances by one packed row per step (wrapping around the corpus).
        idx = [(step + x * batch_stride) % num_packed_data for x in range(0, B)]
        input_batch = input_all[idx]
        target_batch = target_all[idx]

        # Thread the RNN state from one batch into the next.
        state, last_loss = model.train_batch(sess, input_batch, target_batch, state)

        if step % 20 == 0:
          print('step {0}: loss = {1:.3f} (perplexity = {2})'.format(step, last_loss, math.exp(last_loss)))

        # Save at every 200th step and stop, presumably so that a second run
        # exercises the restore path above.
        if step % 200 == 0:
          saved_file = saver.save(sess, os.path.join(checkpoint_dir, "model.ckpt"), global_step=step)
          print("Saved to " + saved_file)
          print("Last state: {0}".format(model.state.eval()))
          break;

        step = step + 1

问题已解决。它与 RNN 和 TensorFlow 无关。

我把

    chars = list(set(data))

改成了

    chars = sorted(set(data))

现在程序可以正常工作了。

这是因为 Python 对字符串使用随机化的哈希函数(哈希随机化)来构建集合,所以每次重新启动 Python 时,'chars' 的顺序都不同。于是字符到索引的映射(char_to_ix)在每次运行时都会变化,恢复出来的 embedding 和 softmax 参数便对应到了错误的字符上,loss 因此猛增。