如何在 C++ 中保存和恢复 TensorFlow 图及其状态?
How to save and restore a TensorFlow graph and its state in C++?
我正在使用 C++ 中的 TensorFlow 训练我的模型。 Python 仅用于构建图形。那么有没有一种方法可以纯粹在 C++ 中保存和恢复图形及其状态?我知道 Python class tf.train.Saver
但据我所知它在 C++ 中不存在。
tf.train.Saver
class 目前仅存在于 Python, 但 (i) 它是从 TensorFlow ops 构建的,您可以 运行 来自 C++,并且 (ii) 它公开了带有操作名称的 Saver.as_saver_def()
method that lets you get a SaverDef
protocol buffer,您必须 运行 保存或恢复模型。
在Python中,您可以获取保存和恢复操作的名称,如下所示:
saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()
# The name of the tensor you must feed with a filename when saving/restoring.
print saver_def.filename_tensor_name
# The name of the target operation you must run when restoring.
print saver_def.restore_op_name
# The name of the target operation you must run when saving.
print saver_def.save_tensor_name
在 C++ 中,要从检查点恢复,您可以调用 Session::Run()
,输入检查点文件的名称 saver_def.filename_tensor_name
,目标操作为 saver_def.restore_op_name
。要保存另一个检查点,您调用 Session::Run()
,再次输入检查点文件的名称 saver_def.filename_tensor_name
,并获取 saver_def.save_tensor_name
.
的值
最新的 TensorFlow 版本包含一些辅助函数,可以在没有 Python 的情况下在 C++ 中执行相同的操作。这些是从 pip 包中的 ProtoBuf 生成的 (${HOME}/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h
)。
// save
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().save_tensor_name()}, nullptr);
// restore
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().restore_op_name()}, nullptr);
这是基于恢复模型的未记录的 python 方式 (more details)
def restore(sess, metaGraph, fn):
restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all'
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const'
sess.run(restore_op, {filename_tensor_name: fn})
对于工作和完整 version see here。
我正在使用 C++ 中的 TensorFlow 训练我的模型。 Python 仅用于构建图形。那么有没有一种方法可以纯粹在 C++ 中保存和恢复图形及其状态?我知道 Python class tf.train.Saver
但据我所知它在 C++ 中不存在。
tf.train.Saver
class 目前仅存在于 Python, 但 (i) 它是从 TensorFlow ops 构建的,您可以 运行 来自 C++,并且 (ii) 它公开了带有操作名称的 Saver.as_saver_def()
method that lets you get a SaverDef
protocol buffer,您必须 运行 保存或恢复模型。
在Python中,您可以获取保存和恢复操作的名称,如下所示:
saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()
# The name of the tensor you must feed with a filename when saving/restoring.
print saver_def.filename_tensor_name
# The name of the target operation you must run when restoring.
print saver_def.restore_op_name
# The name of the target operation you must run when saving.
print saver_def.save_tensor_name
在 C++ 中,要从检查点恢复,您可以调用 Session::Run()
,输入检查点文件的名称 saver_def.filename_tensor_name
,目标操作为 saver_def.restore_op_name
。要保存另一个检查点,您调用 Session::Run()
,再次输入检查点文件的名称 saver_def.filename_tensor_name
,并获取 saver_def.save_tensor_name
.
最新的 TensorFlow 版本包含一些辅助函数,可以在没有 Python 的情况下在 C++ 中执行相同的操作。这些是从 pip 包中的 ProtoBuf 生成的 (${HOME}/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h
)。
// save
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().save_tensor_name()}, nullptr);
// restore
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().restore_op_name()}, nullptr);
这是基于恢复模型的未记录的 python 方式 (more details)
def restore(sess, metaGraph, fn):
restore_op_name = metaGraph.as_saver_def().restore_op_name # u'save/restore_all'
restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name # u'save/Const'
sess.run(restore_op, {filename_tensor_name: fn})
对于工作和完整 version see here。