如何在 C++ 中保存和恢复 TensorFlow 图及其状态?

How to save and restore a TensorFlow graph and its state in C++?

我正在使用 C++ 中的 TensorFlow 训练我的模型。 Python 仅用于构建图形。那么有没有一种方法可以纯粹在 C++ 中保存和恢复图形及其状态?我知道 Python class tf.train.Saver 但据我所知它在 C++ 中不存在。

tf.train.Saver class 目前仅存在于 Python, (i) 它是从 TensorFlow ops 构建的,您可以 运行 来自 C++,并且 (ii) 它公开了带有操作名称的 Saver.as_saver_def() method that lets you get a SaverDef protocol buffer,您必须 运行 保存或恢复模型。

在Python中,您可以获取保存和恢复操作的名称,如下所示:

saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()

# The name of the tensor you must feed with a filename when saving/restoring.
print saver_def.filename_tensor_name

# The name of the target operation you must run when restoring.
print saver_def.restore_op_name

# The name of the target operation you must run when saving.
print saver_def.save_tensor_name

在 C++ 中,要从检查点恢复,您可以调用 Session::Run(),输入检查点文件的名称 saver_def.filename_tensor_name,目标操作为 saver_def.restore_op_name。要保存另一个检查点,您调用 Session::Run(),再次输入检查点文件的名称 saver_def.filename_tensor_name,并获取 saver_def.save_tensor_name.

的值

最新的 TensorFlow 版本包含一些辅助函数,可以在没有 Python 的情况下在 C++ 中执行相同的操作。这些是从 pip 包中的 ProtoBuf 生成的 (${HOME}/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h)。

// save
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().save_tensor_name()}, nullptr);

// restore
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
status = sess->Run(feed_dict, {}, {graph_def.saver_def().restore_op_name()}, nullptr);

这是基于恢复模型的未记录的 python 方式 (more details)

def restore(sess, metaGraph, fn):
    restore_op_name = metaGraph.as_saver_def().restore_op_name   # u'save/restore_all'
    restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
    filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name  # u'save/Const'
    sess.run(restore_op, {filename_tensor_name: fn})

对于工作和完整 version see here