How to export a TF model for serving directly from the session (without creating a TF checkpoint) to minimize export time

I want to share my findings on how to export a TF model for serving directly from a session, without creating a model checkpoint. My use case required the shortest possible time to create the pb file, so I wanted to get the model.pb file directly from the session rather than creating a model checkpoint first.

Most online examples (and the documentation) refer to the common case of creating a model checkpoint and then loading it to create the tf-serving (pb) file. That is perfectly fine, of course, if export time is not a concern; a rough sketch of that conventional flow is shown below for comparison.
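Here is a minimal sketch of that checkpoint-based flow under my own assumptions: the tiny graph, the tensor names x/y, and the paths under /tmp are placeholders I picked for illustration, not part of the original setup.

import os
import tensorflow as tf

ckpt_prefix = '/tmp/ckpt/model'   # illustrative checkpoint prefix
export_dir = '/tmp/export/1'      # must not exist yet; the builder creates it
os.makedirs('/tmp/ckpt', exist_ok=True)

# step 1: the usual route first writes a checkpoint to disk
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    x = tf.compat.v1.placeholder(tf.float32, [None, 4], name='x')
    w = tf.compat.v1.get_variable('w', shape=[4, 2])
    y = tf.identity(tf.matmul(x, w), name='y')
    sess.run(tf.compat.v1.global_variables_initializer())
    tf.compat.v1.train.Saver().save(sess, ckpt_prefix)

# step 2: restore the checkpoint in a fresh session and export the SavedModel
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    saver = tf.compat.v1.train.import_meta_graph(ckpt_prefix + '.meta')
    saver.restore(sess, ckpt_prefix)
    graph = sess.graph
    tf.compat.v1.saved_model.simple_save(
        sess, export_dir,
        inputs={'x': graph.get_tensor_by_name('x:0')},
        outputs={'y': graph.get_tensor_by_name('y:0')})

The checkpoint write/read in the middle is exactly the disk round-trip the code below avoids.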

import os
import tensorflow as tf
from tensorflow.python.framework import importer
output_path = '/export_directory' # be sure to create it before export
input_ops = ['name/s_of_model_input/s']
output_ops = ['name/s_of_model_output/s']
session = tf.compat.v1.Session()

def get_ops_dict(ops, graph, name='op_'):
    # map each tensor name to a TensorInfo proto, keyed e.g. input_0, input_1, ...
    out_dict = dict()
    for i, op in enumerate(ops):
        out_dict[name + str(i)] = tf.compat.v1.saved_model.build_tensor_info(graph.get_tensor_by_name(op + ':0'))
    return out_dict


def add_meta_graph(pbtxt_tmp_path, graph_def):
    with tf.Graph().as_default() as graph:
        # re-import the frozen graph so the renamed inputs/outputs can be looked up by name
        importer.import_graph_def(graph_def, name="")
        # the temporary pbtxt written by export_model is no longer needed
        os.unlink(pbtxt_tmp_path)

        # used to rename model input/outputs
        inputs_dict = get_ops_dict(input_ops, graph, name='input_')
        outputs_dict = get_ops_dict(output_ops, graph, name='output_')

        prediction_signature = (
            tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs=inputs_dict,
                outputs=outputs_dict,
                method_name=tf.saved_model.PREDICT_METHOD_NAME))

        legacy_init_op = tf.group(tf.compat.v1.tables_initializer(), name='legacy_init_op')
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(output_path+'/export')
        builder.add_meta_graph_and_variables(
            session,
            tags=[tf.saved_model.SERVING],
            signature_def_map={
                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature},
            legacy_init_op=legacy_init_op)
        builder.save()
    return prediction_signature

def export_model(session, output_path, output_ops):
    graph_def = session.graph_def
    tf.io.write_graph(graph_or_graph_def=graph_def, logdir=output_path,
                      name='model.pbtxt', as_text=False)

    frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        session, graph_def, output_ops)

    prediction_signature = add_meta_graph(output_path+'/model.pbtxt', frozen_graph_def)
    return prediction_signature
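To tie it together, here is a rough usage sketch under my own assumptions: the tiny graph, the tensor names x/y, and the /tmp output directory are illustrative only; in a real run input_ops, output_ops and output_path at the top of the script would point at your actual model.

# build a small graph in the session created above, then export it without a checkpoint
with session.graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, [None, 4], name='x')
    w = tf.compat.v1.get_variable('w', shape=[4, 2])
    y = tf.identity(tf.matmul(x, w), name='y')
    session.run(tf.compat.v1.global_variables_initializer())

# add_meta_graph reads these module-level names, so point them at the example graph
input_ops = ['x']
output_ops = ['y']
output_path = '/tmp/export_directory'
os.makedirs(output_path, exist_ok=True)   # the output directory must exist beforehand

export_model(session, output_path, output_ops)
# the SavedModel ready for tf-serving now sits under output_path + '/export'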