如何在 python 中加载 tf.saved_model 后获取输入和输出张量
How to grab the input and output tensors after load a tf.saved_model in python
假设我用下面的代码保存了一个模型
tf.saved_model.simple_save(sess, export_dir, in={'input_x': x, 'input_y':y}, out={'output_z':z})
现在我在另一个 python 程序中加载保存的模型作为
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ['serve'], export_dir)
现在的问题是我如何通过在 [=25= 中指定的 'input_x'、'input_y'、'output_z' 键获取 x、y、z 张量的句柄] 调用 simple_save() 方法时的参数?
我在网上找到的唯一解决方案依赖于在创建它们时明确命名 x, y, z 张量,然后使用这些名称从图中检索它们,这似乎是非常多余的,因为我们已经为它们指定了键在调用 simple_save().
tf.saved_model.loader.load
的 return 值是一个 MetaGraphDef
协议缓冲区,它应该具有您在保存已保存模型时设置的所有签名;这些应该包含你想要的名字。
我确实遇到了你的问题,经过一些调查(我认为 TF 文档很差)我找到了下一个解决方案:
使用返回的 MetaGraphDef 对象查找您的输入\输出名称映射。
graph = tf.Graph()
with graph.as_default():
metagraph = tf.saved_model.loader.load(sess, [tag_constants.SERVING],save_path)
inputs_mapping = dict(metagraph.signature_def['serving_default'].inputs)
outputs_mapping = dict(metagraph.signature_def['serving_default'].outputs)
此代码将为您提供保存到“TensorInfo”对象时提供的名称之间的映射,您可以从他那里轻松获取映射的张量名称,例如:
my_input = inputs_mapping['my_input_name'].name
my_input_t = graph.get_tensor_by_name(my_input)
假设我用下面的代码保存了一个模型
tf.saved_model.simple_save(sess, export_dir, in={'input_x': x, 'input_y':y}, out={'output_z':z})
现在我在另一个 python 程序中加载保存的模型作为
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ['serve'], export_dir)
现在的问题是我如何通过在 [=25= 中指定的 'input_x'、'input_y'、'output_z' 键获取 x、y、z 张量的句柄] 调用 simple_save() 方法时的参数?
我在网上找到的唯一解决方案依赖于在创建它们时明确命名 x, y, z 张量,然后使用这些名称从图中检索它们,这似乎是非常多余的,因为我们已经为它们指定了键在调用 simple_save().
tf.saved_model.loader.load
的 return 值是一个 MetaGraphDef
协议缓冲区,它应该具有您在保存已保存模型时设置的所有签名;这些应该包含你想要的名字。
我确实遇到了你的问题,经过一些调查(我认为 TF 文档很差)我找到了下一个解决方案:
使用返回的 MetaGraphDef 对象查找您的输入\输出名称映射。
graph = tf.Graph()
with graph.as_default():
metagraph = tf.saved_model.loader.load(sess, [tag_constants.SERVING],save_path)
inputs_mapping = dict(metagraph.signature_def['serving_default'].inputs)
outputs_mapping = dict(metagraph.signature_def['serving_default'].outputs)
此代码将为您提供保存到“TensorInfo”对象时提供的名称之间的映射,您可以从他那里轻松获取映射的张量名称,例如:
my_input = inputs_mapping['my_input_name'].name
my_input_t = graph.get_tensor_by_name(my_input)