Initializing a TensorFlow Hub module's variables and tables once, so that it is fast enough to use in a RESTful API backend

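Hi everyone, I have a problem and I can't find the best way to solve it.

I have a RESTful API backend in which I want to use a TensorFlow Hub module. The problem is that every time I want to run a computation I have to initialize all the variables and tables, which takes a lot of time. My question is:

Is there a way to initialize all the variables and tables once in a session and then close that session? The workaround I found is to keep the session open and run the computations in it, but that approach ties up resources.

I'm including both the main code and my own solution to make this easier to understand.

A function for loading different modules: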

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.placeholder)
import tensorflow_hub as hub

def loading_module(path=None,
                   module_url='https://tfhub.dev/google/universal-sentence-encoder/2'):
    # Import the Universal Sentence Encoder's TF Hub module
    if path is None:
        embed_object = hub.Module(module_url)
    else:
        embed_object = hub.Module(hub.load_module_spec(path))
    return embed_object

A function for running the text-embedding module:

def run_embedding(embed_object, graph, text):
    # Reduce logging output.
    tf.logging.set_verbosity(tf.logging.ERROR)
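    # A new session is opened, and all variables and tables are
    # re-initialized, on every call; this is the slow part.
    with tf.Session(graph = graph) as sess:
        sess.run([tf.global_variables_initializer(), tf.tables_initializer()])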
        similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
        encoding_tensor = embed_object(similarity_input_placeholder)
        message_embeddings = sess.run(encoding_tensor, feed_dict = {similarity_input_placeholder:text})

    return message_embeddings

embed_object = loading_module()
# run_embedding takes the graph as its second argument
run_embedding(embed_object, tf.get_default_graph(), ['sth'])

My solution:

def loading_module(path = None, module_url = 'https://tfhub.dev/google/universal-sentence-encoder/2'):
    # Import the Universal Sentence Encoder's TF Hub module
    g = tf.Graph()
    with g.as_default():
        if path is None:
            embed_object = hub.Module(module_url)
        else:
            embed_object = hub.Module(hub.load_module_spec(path))
    sess = tf.InteractiveSession(graph = g)
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])

    return embed_object, g, sess


def run_embedding(embed_object, graph, sess, text):
    # Reduce logging output.
    tf.logging.set_verbosity(tf.logging.ERROR)
    with graph.as_default():
        similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
        encoding_tensor = embed_object(similarity_input_placeholder)
        message_embeddings = sess.run(encoding_tensor, feed_dict = {similarity_input_placeholder:text})

    return message_embeddings

You should separate graph construction and session construction from session execution. For example:

def make_embed_fn(module):
  with tf.Graph().as_default():
    sentences = tf.placeholder(tf.string)
    embed = hub.Module(module)
    embeddings = embed(sentences)
    session = tf.train.MonitoredSession()
  return lambda x: session.run(embeddings, {sentences: x})

embed_fn = make_embed_fn('https://tfhub.dev/google/universal-sentence-encoder/2')
embed_fn(["hello 1"])
embed_fn(["hello 2"])
embed_fn(["hello 3"])
...
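
In a backend, the function returned by make_embed_fn has to be created once and then shared by every request handler. One possible way to do that is sketched below. The EmbedService class and its factory argument are hypothetical names introduced for illustration, not part of TensorFlow or TF Hub, and the lock only guards the one-time construction, not the calls made afterwards.

```python
import threading


class EmbedService:
    """Builds an embedding function once and reuses it for every request."""

    def __init__(self, factory):
        # `factory` is a zero-argument callable (an assumption of this sketch)
        # that does the expensive work, such as graph construction, session
        # creation and initialization, and returns a function mapping a list
        # of strings to embeddings.
        self._factory = factory
        self._embed_fn = None
        self._lock = threading.Lock()

    def embed(self, texts):
        # Double-checked locking: only the first caller pays the setup cost,
        # even if several request threads arrive at the same time.
        if self._embed_fn is None:
            with self._lock:
                if self._embed_fn is None:
                    self._embed_fn = self._factory()
        return self._embed_fn(texts)
```

With the make_embed_fn above, this could be created at startup as EmbedService(lambda: make_embed_fn('https://tfhub.dev/google/universal-sentence-encoder/2')), and each request handler would then call its embed method. The module is loaded on the first request; calling embed once during startup would move that cost out of the request path.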

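Also, note that both hub.Module() and hub.load_module_spec can be called with either a path or an https URL, so you don't need the branching you have in your original loading_module. For example: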

# These two are valid uses of the API:
hub.Module("/tmp/my_local_module")
hub.Module("https://tfhub.dev/...")