加载预训练的 word2vec 以在 Estimator model_fn 中初始化 embedding_lookup

Loading pre-trained word2vec to initialise embedding_lookup in the Estimator model_fn

我正在解决文本 class化问题。我使用 Estimator class 和我自己的 model_fn 定义了我的 classifier。我想使用 Google 的预训练 word2vec 嵌入作为初始值,然后针对手头的任务进一步优化它。

我看到了这个post:
这解释了如何在 'raw' TensorFlow 代码中进行处理。不过,我真的很想用Estimatorclass。

作为扩展,我想在 Cloud ML Engine 上训练这段代码,有没有一种好的方法可以传入具有初始值的相当大的文件?

假设我们有类似的东西:

def build_model_fn():
    def _model_fn(features, labels, mode, params):
        input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
        #... what goes here to initialize W

        embedded = tf.nn.embedding_lookup(W, input_layer)
        ...
        return predictions

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(),
    model_dir=MODEL_DIR,
    params=params)
estimator.fit(input_fn=read_data, max_steps=2500)

嵌入通常足够大,唯一可行的方法是使用它们来初始化图中的 tf.Variable。这将允许您利用分布式参数服务器等。

为此(以及其他任何事情),我建议您使用新的 "core" 估算器,tf.estimator.Estimator 因为这会让事情变得容易得多。

根据您提供的 link 中的答案,并且知道我们想要一个变量而不是常量,我们可以采取以下方法:

(2) 使用 feed dict 初始化变量,或者 (3) 从检查点加载变量


我将首先介绍选项 (3),因为它更简单、更好:

在您的 model_fn 中,只需使用 Tensor return 通过 tf.contrib.framework.load_variable 调用初始化一个变量。这需要:

  1. 你的嵌入有一个有效的 TF 检查点
  2. 您知道检查点内嵌入变量的完全限定名称。

代码非常简单:

def model_fn(mode, features, labels, hparams):
  embeddings = tf.Variable(tf.contrib.framework.load_variable(
      'gs://my-bucket/word2vec_checkpoints/',
      'a/fully/qualified/scope/embeddings'
  ))
  ....
  return tf.estimator.EstimatorSpec(...)

但是,如果您的嵌入不是由另一个 TF 模型生成的,则此方法对您不起作用,因此选项 (2)。


对于 (2),我们需要使用 tf.train.Scaffold,它本质上是一个配置对象,其中包含启动 tf.Session 的所有选项(估计器出于多种原因故意隐藏)。

您可以在 tf.train.EstimatorSpec 中指定 Scaffold,在 model_fn 中指定 return。

我们在 model_fn 中创建了一个占位符,并将其设为 我们的嵌入变量的初始化操作,然后通过 Scaffold 传递一个 init_feed_dict。例如

def model_fn(mode, features, labels, hparams):
  embed_ph = tf.placeholder(
      shape=[hparams.vocab_size, hparams.embedding_size], 
      dtype=tf.float32)
  embeddings = tf.Variable(embed_ph)
  # Define your model
  return tf.estimator.EstimatorSpec(
      ..., # normal EstimatorSpec args
      scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array})
  )

这里发生的是 init_feed_dict 将在 运行 时填充 embed_ph 占位符的值,然后允许 embeddings.initialization_op (占位符的赋值), 至 运行.