在 Tensorflow 中保存和恢复经过训练的 LSTM
Saving and Restoring a trained LSTM in Tensor Flow
我使用 BasicLSTMCell 训练了一个 LSTM 分类器。如何保存我的模型并将其恢复以用于以后的分类?
您可以实例化一个 tf.train.Saver
对象并在训练期间调用 save
传递当前会话和输出检查点文件 (*.ckpt) 路径。您可以在您认为合适的任何时候调用 save
(例如,每隔几个 epoch,当验证错误下降时):
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
在 classification/inference 期间,您实例化另一个 tf.train.Saver
并调用 restore
传递当前会话和要恢复的检查点文件。您可以在使用模型进行分类之前调用 restore
,方法是调用 session.run
:
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Do some work with the model
...
参考:https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html#saving-and-restoring
保存和恢复模型的最简单方法是使用tf.train.Saver
对象。构造函数为图中的所有变量或指定列表的变量添加保存和恢复操作。保护程序对象为 运行 这些操作提供方法,指定要写入或读取的检查点文件的路径。
参考:
https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html
检查点文件
变量保存在二进制文件中,大致包含从变量名到张量值的映射。
创建 Saver 对象时,您可以选择为检查点文件中的变量命名。默认情况下,它为每个变量使用 Variable.name 属性 的值。
要了解检查点中有哪些变量,您可以使用 inspect_checkpoint 库,尤其是 print_tensors_in_checkpoint_file 函数。
保存变量
使用 tf.train.Saver() 创建一个 Saver 来管理模型中的所有变量。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
恢复变量
同一个Saver对象用于恢复变量。请注意,当您从文件中恢复变量时,您不必事先初始化它们。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Do some work with the model
...
我自己也在想这个。正如其他人指出的那样,在 TensorFlow 中保存模型的常用方法是使用 tf.train.Saver()
,但我相信这会保存 tf.Variables
的值。
我不确定 BasicLSTMCell
实现中是否有 tf.Variables
会在您执行此操作时自动保存,或者是否可能需要采取其他步骤,但如果所有其他方法均失败, BasicLSTMCell
可以很容易地保存和加载到 pickle 文件中。
我们发现了同样的问题。我们不确定是否保存了内部变量。我们发现您必须在创建/定义 BasicLSTMCell 之后创建保护程序。否则不保存。
是的,LSTM 单元内有权重和偏差变量(事实上,所有神经网络单元都必须在某处有权重变量)。正如其他答案中已经指出的那样,使用 Saver 对象似乎是要走的路......以一种相当方便的方式保存你的变量和你的(元)图。如果你想恢复整个模型,你将需要元图,而不仅仅是一些 tf.Variables 孤立地坐在那里。它确实需要知道它必须保存的所有变量,所以在创建图形后创建保存程序。
处理任何 "is there variables?"/"is it properly reusing weights?"/"how can I actually look at the weights in my LSTM, which isn't bound to any python var?"/等时有用的小技巧。情况是这个小片段:
for i in tf.global_variables():
print(i)
对于变量和
for i in my_graph.get_operations():
print (i)
用于操作。如果您想查看未绑定到 python var 的张量,
tf.Graph.get_tensor_by_name('name_of_op:N')
其中 op 的名称是生成张量的操作的名称,N 是您所追求的(可能是几个)输出张量的索引。
tensorboard 的图形显示可以帮助查找操作名称,如果您的图形有大量操作...最倾向于...
我已经为 LSTM 保存和恢复制作了示例代码。
我也花了很多时间来解决这个问题。
参考这个 url : https://github.com/MareArts/rnn_save_restore_test
希望对这段代码有所帮助。
我使用 BasicLSTMCell 训练了一个 LSTM 分类器。如何保存我的模型并将其恢复以用于以后的分类?
您可以实例化一个 tf.train.Saver
对象并在训练期间调用 save
传递当前会话和输出检查点文件 (*.ckpt) 路径。您可以在您认为合适的任何时候调用 save
(例如,每隔几个 epoch,当验证错误下降时):
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
在 classification/inference 期间,您实例化另一个 tf.train.Saver
并调用 restore
传递当前会话和要恢复的检查点文件。您可以在使用模型进行分类之前调用 restore
,方法是调用 session.run
:
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Do some work with the model
...
参考:https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html#saving-and-restoring
保存和恢复模型的最简单方法是使用tf.train.Saver
对象。构造函数为图中的所有变量或指定列表的变量添加保存和恢复操作。保护程序对象为 运行 这些操作提供方法,指定要写入或读取的检查点文件的路径。
参考:
https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html
检查点文件
变量保存在二进制文件中,大致包含从变量名到张量值的映射。
创建 Saver 对象时,您可以选择为检查点文件中的变量命名。默认情况下,它为每个变量使用 Variable.name 属性 的值。
要了解检查点中有哪些变量,您可以使用 inspect_checkpoint 库,尤其是 print_tensors_in_checkpoint_file 函数。
保存变量
使用 tf.train.Saver() 创建一个 Saver 来管理模型中的所有变量。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)
恢复变量
同一个Saver对象用于恢复变量。请注意,当您从文件中恢复变量时,您不必事先初始化它们。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
print("Model restored.")
# Do some work with the model
...
我自己也在想这个。正如其他人指出的那样,在 TensorFlow 中保存模型的常用方法是使用 tf.train.Saver()
,但我相信这会保存 tf.Variables
的值。
我不确定 BasicLSTMCell
实现中是否有 tf.Variables
会在您执行此操作时自动保存,或者是否可能需要采取其他步骤,但如果所有其他方法均失败, BasicLSTMCell
可以很容易地保存和加载到 pickle 文件中。
我们发现了同样的问题。我们不确定是否保存了内部变量。我们发现您必须在创建/定义 BasicLSTMCell 之后创建保护程序。否则不保存。
是的,LSTM 单元内有权重和偏差变量(事实上,所有神经网络单元都必须在某处有权重变量)。正如其他答案中已经指出的那样,使用 Saver 对象似乎是要走的路......以一种相当方便的方式保存你的变量和你的(元)图。如果你想恢复整个模型,你将需要元图,而不仅仅是一些 tf.Variables 孤立地坐在那里。它确实需要知道它必须保存的所有变量,所以在创建图形后创建保存程序。
处理任何 "is there variables?"/"is it properly reusing weights?"/"how can I actually look at the weights in my LSTM, which isn't bound to any python var?"/等时有用的小技巧。情况是这个小片段:
for i in tf.global_variables():
print(i)
对于变量和
for i in my_graph.get_operations():
print (i)
用于操作。如果您想查看未绑定到 python var 的张量,
tf.Graph.get_tensor_by_name('name_of_op:N')
其中 op 的名称是生成张量的操作的名称,N 是您所追求的(可能是几个)输出张量的索引。
tensorboard 的图形显示可以帮助查找操作名称,如果您的图形有大量操作...最倾向于...
我已经为 LSTM 保存和恢复制作了示例代码。 我也花了很多时间来解决这个问题。 参考这个 url : https://github.com/MareArts/rnn_save_restore_test 希望对这段代码有所帮助。