加载预训练的 word2vec 以在 Estimator model_fn 中初始化 embedding_lookup
Loading pre-trained word2vec to initialise embedding_lookup in the Estimator model_fn
我正在解决文本 class化问题。我使用 Estimator
class 和我自己的 model_fn
定义了我的 classifier。我想使用 Google 的预训练 word2vec
嵌入作为初始值,然后针对手头的任务进一步优化它。
我看到了这个post:
这解释了如何在 'raw' TensorFlow 代码中进行处理。不过,我真的很想用Estimator
class。
作为扩展,我想在 Cloud ML Engine 上训练这段代码,有没有一种好的方法可以传入具有初始值的相当大的文件?
假设我们有类似的东西:
def build_model_fn():
def _model_fn(features, labels, mode, params):
input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
#... what goes here to initialize W
embedded = tf.nn.embedding_lookup(W, input_layer)
...
return predictions
estimator = tf.contrib.learn.Estimator(
model_fn=build_model_fn(),
model_dir=MODEL_DIR,
params=params)
estimator.fit(input_fn=read_data, max_steps=2500)
嵌入通常足够大,唯一可行的方法是使用它们来初始化图中的 tf.Variable
。这将允许您利用分布式参数服务器等。
为此(以及其他任何事情),我建议您使用新的 "core" 估算器,tf.estimator.Estimator
因为这会让事情变得容易得多。
根据您提供的 link 中的答案,并且知道我们想要一个变量而不是常量,我们可以采取以下方法:
(2) 使用 feed dict 初始化变量,或者
(3) 从检查点加载变量
我将首先介绍选项 (3),因为它更简单、更好:
在您的 model_fn
中,只需使用 Tensor
return 通过 tf.contrib.framework.load_variable
调用初始化一个变量。这需要:
- 你的嵌入有一个有效的 TF 检查点
- 您知道检查点内嵌入变量的完全限定名称。
代码非常简单:
def model_fn(mode, features, labels, hparams):
embeddings = tf.Variable(tf.contrib.framework.load_variable(
'gs://my-bucket/word2vec_checkpoints/',
'a/fully/qualified/scope/embeddings'
))
....
return tf.estimator.EstimatorSpec(...)
但是,如果您的嵌入不是由另一个 TF 模型生成的,则此方法对您不起作用,因此选项 (2)。
对于 (2),我们需要使用 tf.train.Scaffold
,它本质上是一个配置对象,其中包含启动 tf.Session
的所有选项(估计器出于多种原因故意隐藏)。
您可以在 tf.train.EstimatorSpec
中指定 Scaffold
,在 model_fn
中指定 return。
我们在 model_fn 中创建了一个占位符,并将其设为
我们的嵌入变量的初始化操作,然后通过 Scaffold
传递一个 init_feed_dict
。例如
def model_fn(mode, features, labels, hparams):
embed_ph = tf.placeholder(
shape=[hparams.vocab_size, hparams.embedding_size],
dtype=tf.float32)
embeddings = tf.Variable(embed_ph)
# Define your model
return tf.estimator.EstimatorSpec(
..., # normal EstimatorSpec args
scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array})
)
这里发生的是 init_feed_dict
将在 运行 时填充 embed_ph
占位符的值,然后允许 embeddings.initialization_op
(占位符的赋值), 至 运行.
我正在解决文本 class化问题。我使用 Estimator
class 和我自己的 model_fn
定义了我的 classifier。我想使用 Google 的预训练 word2vec
嵌入作为初始值,然后针对手头的任务进一步优化它。
我看到了这个post:
这解释了如何在 'raw' TensorFlow 代码中进行处理。不过,我真的很想用Estimator
class。
作为扩展,我想在 Cloud ML Engine 上训练这段代码,有没有一种好的方法可以传入具有初始值的相当大的文件?
假设我们有类似的东西:
def build_model_fn():
def _model_fn(features, labels, mode, params):
input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
#... what goes here to initialize W
embedded = tf.nn.embedding_lookup(W, input_layer)
...
return predictions
estimator = tf.contrib.learn.Estimator(
model_fn=build_model_fn(),
model_dir=MODEL_DIR,
params=params)
estimator.fit(input_fn=read_data, max_steps=2500)
嵌入通常足够大,唯一可行的方法是使用它们来初始化图中的 tf.Variable
。这将允许您利用分布式参数服务器等。
为此(以及其他任何事情),我建议您使用新的 "core" 估算器,tf.estimator.Estimator
因为这会让事情变得容易得多。
根据您提供的 link 中的答案,并且知道我们想要一个变量而不是常量,我们可以采取以下方法:
(2) 使用 feed dict 初始化变量,或者 (3) 从检查点加载变量
我将首先介绍选项 (3),因为它更简单、更好:
在您的 model_fn
中,只需使用 Tensor
return 通过 tf.contrib.framework.load_variable
调用初始化一个变量。这需要:
- 你的嵌入有一个有效的 TF 检查点
- 您知道检查点内嵌入变量的完全限定名称。
代码非常简单:
def model_fn(mode, features, labels, hparams):
embeddings = tf.Variable(tf.contrib.framework.load_variable(
'gs://my-bucket/word2vec_checkpoints/',
'a/fully/qualified/scope/embeddings'
))
....
return tf.estimator.EstimatorSpec(...)
但是,如果您的嵌入不是由另一个 TF 模型生成的,则此方法对您不起作用,因此选项 (2)。
对于 (2),我们需要使用 tf.train.Scaffold
,它本质上是一个配置对象,其中包含启动 tf.Session
的所有选项(估计器出于多种原因故意隐藏)。
您可以在 tf.train.EstimatorSpec
中指定 Scaffold
,在 model_fn
中指定 return。
我们在 model_fn 中创建了一个占位符,并将其设为
我们的嵌入变量的初始化操作,然后通过 Scaffold
传递一个 init_feed_dict
。例如
def model_fn(mode, features, labels, hparams):
embed_ph = tf.placeholder(
shape=[hparams.vocab_size, hparams.embedding_size],
dtype=tf.float32)
embeddings = tf.Variable(embed_ph)
# Define your model
return tf.estimator.EstimatorSpec(
..., # normal EstimatorSpec args
scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array})
)
这里发生的是 init_feed_dict
将在 运行 时填充 embed_ph
占位符的值,然后允许 embeddings.initialization_op
(占位符的赋值), 至 运行.