将一个模型中保存的单个变量张量恢复为另一个模型中的变量张量 - Tensorflow
Restoring a single variable tensor saved in one model to variable tensor in another model - Tensorflow
运行 在 tensorflow 1.3.0 GPU 上。
我在 TF 中训练了一个模型,并使用以下方法仅保存了一个 varialbe 张量:
embeddings = tf.Variable(tf.random_uniform([4**kmer_len, embedding_size], -0.04, 0.04), name='Embeddings')
more code, variables...
saver = tf.train.Saver({"Embeddings": embeddings}) # saving only embeddings variable
some more code, training model...
saver.save(ses, './embeddings/embedding_mat') # saving the variable
现在,我在不同的文件中有一个不同的模型,我想只恢复保存的单个 embeddings 变量。问题是这个新模型有更多的变量。
现在,当我尝试通过以下方式恢复变量时:
embeddings = tf.Variable(tf.random_uniform([4**kmer_len_emb, embedding_size], -0.04, 0.04), name='Embeddings')
dense1 = tf.layers.dense(inputs=kmer_flattened, units=200, activation=tf.nn.relu, use_bias=True)
ses = tf.Session()
init = tf.global_variables_initializer()
ses.run(init)
saver = tf.train.Saver()
saver.restore(ses, './embeddings/embedding_mat')
我收到“未在检查点中找到”错误。
关于如何处理这个的任何想法?
谢谢
是因为找不到dense1
检查点。试试这个:
all_var = tf.global_variables()
var_to_restore = [v for v in all_var if v.name == 'Embeddings:0']
ses.run(init)
saver = tf.train.Saver(var_to_restore)
saver.restore(ses, './embeddings/embedding_mat')
您必须在该变量上创建 Saver
的实例:
saver = tf.train.Saver(var_list=[embeddings])
这是在告诉您的 Saver
实例只处理 restoring/saving 该图形的特定变量,否则它将尝试 restore/save 图形的所有变量。
运行 在 tensorflow 1.3.0 GPU 上。 我在 TF 中训练了一个模型,并使用以下方法仅保存了一个 varialbe 张量:
embeddings = tf.Variable(tf.random_uniform([4**kmer_len, embedding_size], -0.04, 0.04), name='Embeddings')
more code, variables...
saver = tf.train.Saver({"Embeddings": embeddings}) # saving only embeddings variable
some more code, training model...
saver.save(ses, './embeddings/embedding_mat') # saving the variable
现在,我在不同的文件中有一个不同的模型,我想只恢复保存的单个 embeddings 变量。问题是这个新模型有更多的变量。 现在,当我尝试通过以下方式恢复变量时:
embeddings = tf.Variable(tf.random_uniform([4**kmer_len_emb, embedding_size], -0.04, 0.04), name='Embeddings')
dense1 = tf.layers.dense(inputs=kmer_flattened, units=200, activation=tf.nn.relu, use_bias=True)
ses = tf.Session()
init = tf.global_variables_initializer()
ses.run(init)
saver = tf.train.Saver()
saver.restore(ses, './embeddings/embedding_mat')
我收到“未在检查点中找到”错误。 关于如何处理这个的任何想法? 谢谢
是因为找不到dense1
检查点。试试这个:
all_var = tf.global_variables()
var_to_restore = [v for v in all_var if v.name == 'Embeddings:0']
ses.run(init)
saver = tf.train.Saver(var_to_restore)
saver.restore(ses, './embeddings/embedding_mat')
您必须在该变量上创建 Saver
的实例:
saver = tf.train.Saver(var_list=[embeddings])
这是在告诉您的 Saver
实例只处理 restoring/saving 该图形的特定变量,否则它将尝试 restore/save 图形的所有变量。