无法在 tensorflow 1.1 中从上一个会话加载 int 变量

Cannot load int variable from previous session in tensorflow 1.1

我读过很多类似的问题,但始终无法让它正常工作。

我的模型训练得很好,每个 epoch 都会生成一个检查点文件。我希望程序在重新加载后能够从第 x 个 epoch 继续训练,并且在每次迭代时打印当前所处的 epoch。我当然可以把这个数值保存在检查点文件之外,但我想通过检查点来恢复它,这样也能让我确信其他所有内容都已被正确保存。

不幸的是,当我重新启动时,epoch/global_step 变量中的值始终为 0。

import tensorflow as tf
import numpy as np
import tensorflow as tf
import numpy as np
# more imports


def extract_number(f): # used to get latest checkpoint file
    """Return an ``(epoch, path)`` sort key for a checkpoint file name.

    The epoch is the integer found in the first ``epoch<N>.ckpt`` occurrence
    inside ``f``. Names without such a fragment get ``-1`` so that they sort
    before every real checkpoint. ``f`` itself is the second tuple element,
    which makes ties between files of the same epoch resolve by file name.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence, and
    # the dot is escaped so that it only matches a literal ".".
    match = re.search(r"epoch(\d+)\.ckpt", f)
    return (int(match.group(1)) if match else -1, f)

def restore(init_op, sess, saver): # called to restore or just initialise model
    """Restore the newest checkpoint under ./params (if any), then run init_op.

    Args:
        init_op: operation passed to ``sess.run`` at the end of the function.
        sess: session used for both the restore and the ``init_op`` run.
        saver: object whose ``restore(sess, path)`` loads the checkpoint.

    Returns:
        None.
    """
    # NOTE: `list` and `file` shadow Python builtins inside this function.
    list = glob(os.path.join("./params/e*"))

    if list:

        # extract_number yields (epoch, filename), so the highest epoch wins
        # and ties between files of the same epoch are broken by file name.
        file = max(list,key=extract_number)

        # Drops the last five characters of the chosen file name; presumably
        # this strips a ".meta" suffix to obtain the checkpoint prefix --
        # TODO confirm against the files actually present in ./params.
        saver.restore(sess, file[:-5])


    # NOTE(review): this runs unconditionally, i.e. also right after
    # saver.restore. Running the initializer after a restore resets the
    # variables to their initial values, which is why the restored
    # epoch/global_step is always read back as 0.
    sess.run(init_op)
    return


# Training driver: builds the graph, restores or initialises the variables,
# then trains until `epochLimit`, saving a checkpoint and a voxel sample after
# every epoch. Names such as data, batch_size, z_size, global_step, newData,
# opt_G, opt_D, loss_G, loss_D, z, train, input, x_ and util are defined
# outside the lines shown here.
with tf.Graph().as_default() as g:

    # build models


    total_batch = data.train.num_examples / batch_size

    epochLimit = 51

    # NOTE: saver and init_op are created again inside the session block
    # below; those later objects are the ones actually used.
    saver = tf.train.Saver()

    init_op = tf.global_variables_initializer()


    with tf.Session() as sess:


        saver = tf.train.Saver()

        init_op = tf.global_variables_initializer()

        # NOTE(review): restore() as defined in this file runs init_op even
        # after a successful saver.restore, so the values loaded from the
        # checkpoint are overwritten by the initial values.
        restore(init_op, sess, saver)


        # Starting epoch is read from the global_step variable.
        epoch = global_step.eval()


        while epoch < epochLimit:

            # Recomputed every epoch; same expression as the one above.
            total_batch = data.train.num_examples / batch_size

            for i in range(int(total_batch)):

                sys.stdout.flush()

                voxels = newData.eval()

                # Uniform noise in [-1, 1) of shape (batch_size, z_size), float32.
                batch_z = np.random.uniform(-1, 1, [batch_size, z_size]).astype(np.float32)

                # One run of opt_G followed by one run of opt_D per batch.
                sess.run(opt_G, feed_dict={z:batch_z, train:True})
                sess.run(opt_D, feed_dict={input:voxels, z:batch_z, train:True})


                # NOTE(review): "out/loss.csv" is opened for append but `f` is
                # never written to in this block; the message is only printed.
                with open("out/loss.csv", 'a') as f:
                    # Losses are evaluated on the same batch with train fed as False.
                    batch_loss_G = sess.run(loss_G, feed_dict={z:batch_z, train:False})
                    batch_loss_D = sess.run(loss_D, feed_dict={input:voxels, z:batch_z, train:False})
                    msgOut = "Epoch: [{0}], i: [{1}], G_Loss[{2:.8f}], D_Loss[{3:.8f}]".format(epoch, i, batch_loss_G, batch_loss_D)

                    print(msgOut)

            # Advance the Python-side counter, mirror it into global_step, and
            # save so that the checkpoint carries the new epoch number.
            epoch=epoch+1
            sess.run(global_step.assign(epoch))
            saver.save(sess, "params/epoch{0}.ckpt".format(epoch))

            # Evaluate x_ on fresh noise and export the first sample.
            batch_z = np.random.uniform(-1, 1, [batch_size, z_size]).astype(np.float32)
            voxels = sess.run(x_, feed_dict={z:batch_z})

            # Reshape the first sample to 32x32x32 and threshold at 0 to get a
            # boolean occupancy grid.
            v = voxels[0].reshape([32, 32, 32]) > 0
            util.save_binvox(v, "out/epoch{0}.vox".format(epoch), 32)

我还在代码底部用 assign 更新了 global_step 变量。有什么想法吗?非常感谢任何帮助。

如果在恢复(restore)之后再调用 sess.run(init_op),所有变量都会被重置为初始值。把那一行注释掉,应该就能正常工作了。

我最初的代码有多处错误,因为我当时尝试了很多种做法。第一位回答者 Alexandre Passos 指出的问题是成立的,但我认为(也许)变量作用域(variable scope)的使用同样起了关键作用。

如果对任何人有帮助,以下是有效的更新代码:

import tensorflow as tf
import numpy as np
# more imports


def extract_number(f): # used to get latest checkpoint file
    """Build the key used to pick the most recent checkpoint file.

    Returns ``(epoch, f)`` where ``epoch`` is the number in the first
    ``epoch<N>.ckpt`` fragment of ``f``, or ``-1`` when ``f`` contains no such
    fragment. Keeping ``f`` as the second element means files belonging to the
    same epoch are ordered by their names.
    """
    # A raw pattern avoids the invalid "\d" escape of a plain string literal;
    # "\." matches only a literal dot instead of any character.
    found = re.search(r"epoch(\d+)\.ckpt", f)
    if found is None:
        return (-1, f)
    return (int(found.group(1)), f)

def restore(sess, saver):
    """Restore the newest checkpoint under ./params, or initialise from scratch.

    Args:
        sess: session to restore into or to run the initializer in.
        saver: saver used for the restore when checkpoint files exist.

    Returns:
        A ``(saver, restored, sess)`` tuple. ``restored`` is True when a
        checkpoint was loaded. When nothing was found, the returned saver is a
        newly created ``tf.train.Saver`` and the variables have been
        initialised instead.
    """
    pattern = os.path.join("./params/e*")
    candidates = glob(pattern)

    if not candidates:
        # No checkpoint files: create a new saver and run the initializer.
        saver = tf.train.Saver()
        fresh_init = tf.global_variables_initializer()
        sess.run(fresh_init)
        return saver, False, sess

    # Highest epoch wins; ties are broken by file name via the tuple key.
    newest = max(candidates, key=extract_number)
    # The last five characters of the file name are dropped before restoring.
    checkpoint_prefix = newest[:-5]
    saver.restore(sess, checkpoint_prefix)
    return saver, True, sess


# Run configuration shared by the code below.
# batch_size: examples per training batch; also the first dimension of the
# noise batches.
batch_size = 100
# learning_rate / beta1: not referenced in the lines shown here; presumably
# optimizer settings consumed by the elided model-building code -- TODO confirm.
learning_rate = 0.0001
beta1 = 0.5
# z_size: second dimension of the noise batches.
z_size = 100
# save_interval: not referenced anywhere in the visible code.
save_interval = 1

# Loaded once at import time; `dataset` comes from the elided imports.
data = dataset.read()

# Batches per epoch. The division result is wrapped in int() wherever it is
# used as a loop bound.
total_batch = data.train.num_examples / batch_size

def fill_queue():
    """Enqueue one training batch per step for the whole training run.

    Runs ``enqueue_op`` ``int(total_batch * epochLimit)`` times, feeding it
    the next batch from ``data.train`` each time. Relies on the module-level
    ``sess``, ``enqueue_op``, ``batch``, ``data``, ``total_batch``,
    ``epochLimit`` and ``batch_size``.
    """
    # Original author's note: running in a separate thread to feed a FIFO queue.
    steps = int(total_batch*epochLimit)
    for _ in range(steps):
        next_examples = data.train.next_batch(batch_size)
        sess.run(enqueue_op, feed_dict={batch: next_examples})



# Training driver (working version): creates the epoch counter variable,
# restores or initialises the model, then trains until `epochLimit`, saving a
# checkpoint and a voxel sample after every epoch. newData, opt_G, opt_D,
# loss_G, loss_D, z, train, input, x_ and util come from code not shown here.

# Integer, non-trainable epoch counter created in the "glob" scope with an
# initial value of 0.
with tf.variable_scope("glob"):
    global_step = tf.get_variable(name='global_step', initializer=0,trainable=False)

# build models

epochLimit = 51

saver = tf.train.Saver()

with tf.Session() as sess:

    # restore() returns (saver, restored_flag, sess); `rstr` is not used again
    # in the visible code.
    saver,rstr,sess = restore(sess, saver)



    # Looks the counter up again by name with reuse=True; presumably this
    # yields the same variable as `global_step` above -- TODO confirm.
    with tf.variable_scope("glob", reuse=True):
        epocht = tf.get_variable(name='global_step', trainable=False, dtype=tf.int32)

    # Starting epoch: the value currently held by the counter variable.
    epoch = epocht.eval()


    while epoch < epochLimit:

        # Recomputed every epoch; same expression as the module-level one.
        total_batch = data.train.num_examples / batch_size

        for i in range(int(total_batch)):

            sys.stdout.flush()

            voxels = newData.eval()

            # Uniform noise in [-1, 1) of shape (batch_size, z_size), float32.
            batch_z = np.random.uniform(-1, 1, [batch_size, z_size]).astype(np.float32)

            # One run of opt_G followed by one run of opt_D per batch.
            sess.run(opt_G, feed_dict={z:batch_z, train:True})
            sess.run(opt_D, feed_dict={input:voxels, z:batch_z, train:True})


            # NOTE(review): "out/loss.csv" is opened for append but `f` is
            # never written to in this block; the message is only printed.
            with open("out/loss.csv", 'a') as f:
                # Losses are evaluated on the same batch with train fed as False.
                batch_loss_G = sess.run(loss_G, feed_dict={z:batch_z, train:False})
                batch_loss_D = sess.run(loss_D, feed_dict={input:voxels, z:batch_z, train:False})
                msgOut = "Epoch: [{0}], i: [{1}], G_Loss[{2:.8f}], D_Loss[{3:.8f}]".format(epoch, i, batch_loss_G, batch_loss_D)

                print(msgOut)

        # Advance the Python-side counter, mirror it into global_step, and save
        # so that the checkpoint carries the new epoch number.
        # NOTE(review): global_step.assign(epoch) is called on every epoch;
        # presumably each call adds a new op to the graph -- TODO confirm.
        epoch=epoch+1
        sess.run(global_step.assign(epoch))
        saver.save(sess, "params/epoch{0}.ckpt".format(epoch))

        # Evaluate x_ on fresh noise and export the first sample.
        batch_z = np.random.uniform(-1, 1, [batch_size, z_size]).astype(np.float32)
        voxels = sess.run(x_, feed_dict={z:batch_z})

        # Reshape the first sample to 32x32x32 and threshold at 0 to get a
        # boolean occupancy grid.
        v = voxels[0].reshape([32, 32, 32]) > 0
        util.save_binvox(v, "out/epoch{0}.vox".format(epoch), 32)