Tensorflow:如何将 .meta、.data 和 .index 模型文件转换为一个 graph.pb 文件
Tensorflow: How to convert .meta, .data and .index model files into one graph.pb file
在tensorflow中从头开始训练产生了以下6个文件:
- events.out.tfevents.1503494436.06L7-BRM738
- model.ckpt-22480.meta
- checkpoint
- model.ckpt-22480.data-00000-of-00001
- model.ckpt-22480.index
- graph.pbtxt
我想将它们(或仅需要的)转换成一个文件 graph.pb 以便能够将其传输到我的 Android 应用程序.
我尝试了脚本 freeze_graph.py
,但它已经需要 input.pb 文件作为输入,而我没有。 (我只有前面提到的这6个文件)。如何继续获取这个 freezed_graph.pb 文件?我看到了几个线程,但 none 对我有用。
您可以使用这个简单的脚本来做到这一点。但是您必须指定输出节点的名称。
import tensorflow as tf
meta_path = 'model.ckpt-22480.meta' # Your .meta file
output_node_names = ['output:0'] # Output nodes
with tf.Session() as sess:
# Restore the graph
saver = tf.train.import_meta_graph(meta_path)
# Load weights
saver.restore(sess,tf.train.latest_checkpoint('path/of/your/.meta/file'))
# Freeze the graph
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph_def,
output_node_names)
# Save the frozen graph
with open('output_graph.pb', 'wb') as f:
f.write(frozen_graph_def.SerializeToString())
如果不知道输出节点的名称,有两种方法
您可以浏览图表并使用 Netron or with console summarize_graph 实用程序查找名称。
您可以使用所有节点作为输出节点,如下所示。
output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]
(请注意,您必须将此行放在 convert_variables_to_constants
调用之前。)
但我认为这是不寻常的情况,因为如果你不知道输出节点,你就无法实际使用该图。
因为它可能对其他人有帮助,所以我也在 github 的回答之后在这里回答;-)。
我想你可以尝试这样的事情(使用 tensorflow/python/tools 中的 freeze_graph 脚本):
python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "
这里重要的标志是 --input_binary=false,因为文件 graph.pbtxt 是文本格式。我认为它对应于所需的 graph.pb ,这在二进制格式中是等效的。
关于 output_node_names,这让我很困惑,因为我在这部分仍然有一些问题,但是你可以使用 tensorflow 中的 summarize_graph 脚本,它可以将 pb 或 pbtxt 作为输入。
此致,
斯蒂芬
我尝试了 freezed_graph.py 脚本,但 output_node_name 参数完全令人困惑。作业失败。
所以我尝试了另一个:export_inference_graph.py。
它按预期工作!
python -u /tfPath/models/object_detection/export_inference_graph.py \
--input_type=image_tensor \
--pipeline_config_path=/your/config/path/ssd_mobilenet_v1_pets.config \
--trained_checkpoint_prefix=/your/checkpoint/path/model.ckpt-50000 \
--output_directory=/output/path
我使用的tensorflow安装包来自这里:
https://github.com/tensorflow/models
首先,使用以下代码生成graph.pb文件。
用 tf.Session() 作为 sess:
# Restore the graph
_ = tf.train.import_meta_graph(args.input)
# save graph file
g = sess.graph
gdef = g.as_graph_def()
tf.train.write_graph(gdef, ".", args.output, True)
然后,使用summarize graph得到输出节点名。
最后,使用
python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "
生成冻结图。
在tensorflow中从头开始训练产生了以下6个文件:
- events.out.tfevents.1503494436.06L7-BRM738
- model.ckpt-22480.meta
- checkpoint
- model.ckpt-22480.data-00000-of-00001
- model.ckpt-22480.index
- graph.pbtxt
我想将它们(或仅需要的)转换成一个文件 graph.pb 以便能够将其传输到我的 Android 应用程序.
我尝试了脚本 freeze_graph.py
,但它已经需要 input.pb 文件作为输入,而我没有。 (我只有前面提到的这6个文件)。如何继续获取这个 freezed_graph.pb 文件?我看到了几个线程,但 none 对我有用。
您可以使用这个简单的脚本来做到这一点。但是您必须指定输出节点的名称。
import tensorflow as tf
meta_path = 'model.ckpt-22480.meta' # Your .meta file
output_node_names = ['output:0'] # Output nodes
with tf.Session() as sess:
# Restore the graph
saver = tf.train.import_meta_graph(meta_path)
# Load weights
saver.restore(sess,tf.train.latest_checkpoint('path/of/your/.meta/file'))
# Freeze the graph
frozen_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph_def,
output_node_names)
# Save the frozen graph
with open('output_graph.pb', 'wb') as f:
f.write(frozen_graph_def.SerializeToString())
如果不知道输出节点的名称,有两种方法
您可以浏览图表并使用 Netron or with console summarize_graph 实用程序查找名称。
您可以使用所有节点作为输出节点,如下所示。
output_node_names = [n.name for n in tf.get_default_graph().as_graph_def().node]
(请注意,您必须将此行放在 convert_variables_to_constants
调用之前。)
但我认为这是不寻常的情况,因为如果你不知道输出节点,你就无法实际使用该图。
因为它可能对其他人有帮助,所以我也在 github 的回答之后在这里回答;-)。 我想你可以尝试这样的事情(使用 tensorflow/python/tools 中的 freeze_graph 脚本):
python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "
这里重要的标志是 --input_binary=false,因为文件 graph.pbtxt 是文本格式。我认为它对应于所需的 graph.pb ,这在二进制格式中是等效的。
关于 output_node_names,这让我很困惑,因为我在这部分仍然有一些问题,但是你可以使用 tensorflow 中的 summarize_graph 脚本,它可以将 pb 或 pbtxt 作为输入。
此致,
斯蒂芬
我尝试了 freezed_graph.py 脚本,但 output_node_name 参数完全令人困惑。作业失败。
所以我尝试了另一个:export_inference_graph.py。 它按预期工作!
python -u /tfPath/models/object_detection/export_inference_graph.py \
--input_type=image_tensor \
--pipeline_config_path=/your/config/path/ssd_mobilenet_v1_pets.config \
--trained_checkpoint_prefix=/your/checkpoint/path/model.ckpt-50000 \
--output_directory=/output/path
我使用的tensorflow安装包来自这里: https://github.com/tensorflow/models
首先,使用以下代码生成graph.pb文件。 用 tf.Session() 作为 sess:
# Restore the graph
_ = tf.train.import_meta_graph(args.input)
# save graph file
g = sess.graph
gdef = g.as_graph_def()
tf.train.write_graph(gdef, ".", args.output, True)
然后,使用summarize graph得到输出节点名。 最后,使用
python freeze_graph.py --input_graph=/path/to/graph.pbtxt --input_checkpoint=/path/to/model.ckpt-22480 --input_binary=false --output_graph=/path/to/frozen_graph.pb --output_node_names="the nodes that you want to output e.g. InceptionV3/Predictions/Reshape_1 for Inception V3 "
生成冻结图。