How to convert .ckpt to .pb?
I'm new to deep learning. I want to serve a pretrained EAST model with AI Platform Serving, and the developers provided these files:
- model.ckpt-49491.data-00000-of-00001
- checkpoint
- model.ckpt-49491.index
- model.ckpt-49491.meta
I want to convert it to the TensorFlow .pb format. Is there a way to do this? I got the model from here.
The full code is available here.
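The four files listed above make up a single checkpoint. The `.index`, `.meta` and `.data-00000-of-00001` files share the prefix `model.ckpt-49491`, and the `checkpoint` file records which prefix is the latest. Loading code such as `import_meta_graph` and `restore` takes that prefix, not any single file. A minimal plain-Python sketch of recovering the prefix from the file names (`checkpoint_prefix` is a hypothetical helper, not a TensorFlow function):

```python
import re

# Suffixes that TensorFlow 1.x checkpoint files append to the shared prefix.
_SUFFIX = re.compile(r"\.(index|meta|data-\d{5}-of-\d{5})$")


def checkpoint_prefix(filenames):
    """Return the single prefix shared by a set of checkpoint files (hypothetical helper)."""
    # Files without a checkpoint suffix (e.g. the plain "checkpoint" file) are ignored.
    prefixes = {_SUFFIX.sub("", name) for name in filenames if _SUFFIX.search(name)}
    if len(prefixes) != 1:
        raise ValueError(f"expected one checkpoint prefix, found {sorted(prefixes)}")
    return prefixes.pop()


files = [
    "model.ckpt-49491.data-00000-of-00001",
    "checkpoint",
    "model.ckpt-49491.index",
    "model.ckpt-49491.meta",
]
print(checkpoint_prefix(files))  # model.ckpt-49491
```

In the conversion code, this prefix is what goes into `trained_checkpoint_prefix`, usually with the directory prepended (for example `models/model.ckpt-49491`).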
I have looked here, and it shows the following code to convert it:
# From tensorflow/models/research/
INPUT_TYPE=image_tensor
PIPELINE_CONFIG_PATH={path to pipeline config file}
TRAINED_CKPT_PREFIX={path to model.ckpt}
EXPORT_DIR={path to folder that will be used for export}
python object_detection/export_inference_graph.py \
--input_type=${INPUT_TYPE} \
--pipeline_config_path=${PIPELINE_CONFIG_PATH} \
--trained_checkpoint_prefix=${TRAINED_CKPT_PREFIX} \
--output_directory=${EXPORT_DIR}
I don't know what values to pass for:
- INPUT_TYPE
- PIPELINE_CONFIG_PATH
Here is the code to convert the checkpoint to a SavedModel:
import os
import tensorflow as tf
trained_checkpoint_prefix = 'models/model.ckpt-49491'
export_dir = os.path.join('export_dir', '0')
graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:
    # Restore from checkpoint
    loader = tf.compat.v1.train.import_meta_graph(trained_checkpoint_prefix + '.meta')
    loader.restore(sess, trained_checkpoint_prefix)
    # Export checkpoint to SavedModel
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess,
                                         [tf.saved_model.TRAINING, tf.saved_model.SERVING],
                                         strip_default_attrs=True)
    builder.save()
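After `builder.save()`, the export directory should contain a `saved_model.pb` file and a `variables/` folder holding `variables.index` plus the data shards. A quick plain-Python sanity check of that layout is sketched below. `looks_like_saved_model` is a hypothetical helper: it only checks that the files exist and does not load the graph. To actually inspect the tags and signatures, TensorFlow's `saved_model_cli show --dir export_dir/0 --all` is the usual tool.

```python
import os


def looks_like_saved_model(export_dir):
    """Rough check for the standard SavedModel file layout (hypothetical helper)."""
    # The serialized graph(s) and tags live in saved_model.pb.
    has_graph = os.path.isfile(os.path.join(export_dir, "saved_model.pb"))
    # The restored weights are written under variables/.
    has_index = os.path.isfile(os.path.join(export_dir, "variables", "variables.index"))
    return has_graph and has_index


print(looks_like_saved_model(os.path.join("export_dir", "0")))
```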
Based on @Puneith Kaul's answer, here is the syntax for TensorFlow 1.7:
import os
import tensorflow as tf
export_dir = 'export_dir'
trained_checkpoint_prefix = 'models/model.ckpt'
# Rebuild the graph from the .meta file into the default graph
loader = tf.train.import_meta_graph(trained_checkpoint_prefix + ".meta")
with tf.Session() as sess:
    # Load the trained weights into the session
    loader.restore(sess, trained_checkpoint_prefix)
    # Write graph and variables as a SavedModel tagged for training and serving
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.TRAINING, tf.saved_model.tag_constants.SERVING],
        strip_default_attrs=True)
    builder.save()
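The first snippet writes into `export_dir/0`, while this one writes directly into `export_dir`. TensorFlow Serving looks for numeric version subdirectories under a model's base path, and `SavedModelBuilder` is generally expected to refuse a directory that already has contents, so re-exporting needs a fresh folder either way. A minimal sketch of picking the next free version number (`next_version_dir` is a hypothetical helper, and starting at `0` simply mirrors the first snippet):

```python
import os


def next_version_dir(base_dir):
    """Return base_dir/<N>, N being one past the largest numeric subdirectory (hypothetical helper)."""
    versions = []
    if os.path.isdir(base_dir):
        for name in os.listdir(base_dir):
            # Only purely numeric subdirectories count as model versions.
            if name.isdigit() and os.path.isdir(os.path.join(base_dir, name)):
                versions.append(int(name))
    next_version = max(versions) + 1 if versions else 0
    return os.path.join(base_dir, str(next_version))


export_dir = next_version_dir("export_dir")
print(export_dir)
```

The returned path can be passed straight to `SavedModelBuilder(export_dir)`. The helper does not create the directory itself.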
Set INPUT_TYPE to image_tensor and PIPELINE_CONFIG_PATH to your pipeline config file, then run this command:
python object_detection/export_inference_graph.py \
--input_type=${INPUT_TYPE} \
--pipeline_config_path=${PIPELINE_CONFIG_PATH} \
--trained_checkpoint_prefix=${TRAINED_CKPT_PREFIX} \
--output_directory=${EXPORT_DIR}
The export directory then contains the model in three formats:
- frozen_graph.pb
- savedmodel.pb
- checkpoint
For more information, see https://github.com/tensorflow/models/tree/master/research/object_detection
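If you drive the export from a Python script, the same `export_inference_graph.py` invocation can be assembled as an argument list. In this minimal sketch, `build_export_command` is a hypothetical helper, and `path/to/pipeline.config`, `models/model.ckpt-49491` and `export_dir` are placeholders for your own paths. It only builds and prints the command. Actually running it still requires the Object Detection API and should happen from `tensorflow/models/research/`.

```python
import shlex


def build_export_command(input_type, pipeline_config_path,
                         trained_ckpt_prefix, output_directory):
    """Assemble the export_inference_graph.py call as an argv list (hypothetical helper)."""
    return [
        "python", "object_detection/export_inference_graph.py",
        f"--input_type={input_type}",
        f"--pipeline_config_path={pipeline_config_path}",
        f"--trained_checkpoint_prefix={trained_ckpt_prefix}",
        f"--output_directory={output_directory}",
    ]


# Placeholder paths: substitute your own config file, checkpoint prefix and export folder.
cmd = build_export_command(
    "image_tensor",
    "path/to/pipeline.config",
    "models/model.ckpt-49491",
    "export_dir",
)
print(shlex.join(cmd))
# To execute it: subprocess.run(cmd, check=True, cwd="path/to/tensorflow/models/research")
```

Passing a list rather than a single shell string avoids quoting problems when paths contain spaces.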