使用 tensorflow 提取 ELMo 特征并将它们转换为 numpy

Extracting ELMo features using tensorflow and convert them to numpy

所以我对使用 ELMo 模型提取句子嵌入很感兴趣。

我一开始试过这个:

import tensorflow as tf
import tensorflow_hub as hub
import numpy as np

elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=True)

x = ["Hi my friend"]

embeddings = elmo_model(x, signature="default", as_dict=True)["elmo"]


print(embeddings.shape)
print(embeddings.numpy())

它在最后一行之前运行良好,我无法将它转换为 numpy 数组。

我稍微搜索了一下,发现如果我在代码的开头加上下面这行,问题一定会解决。

tf.enable_eager_execution()

但是,我把它放在代码的开头,我意识到我无法编译

elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=True)

我收到这个错误:

Exporting/importing meta graphs is not supported when eager execution is enabled. No graph exists when eager execution is enabled.

我该如何解决我的问题?我的目标是获取句子特征并在NumPy数组中使用它们。

提前致谢

TF 2.x

TF2 行为更接近于经典的 python 行为,因为它默认为立即执行。但是,您应该使用 hub.load 在 TF2 中加载您的模型。

elmo = hub.load("https://tfhub.dev/google/elmo/2").signatures["default"]
x = ["Hi my friend"]
embeddings = elmo(tf.constant(x))["elmo"]

然后,您可以访问结果并使用 numpy 方法将它们转换为 numpy 数组。

>>> embeddings.numpy()
array([[[-0.7205108 , -0.27990735, -0.7735629 , ..., -0.24703965,
         -0.8358178 , -0.1974785 ],
        [ 0.18500198, -0.12270843, -0.35163105, ...,  0.14234722,
          0.08479916, -0.11709933],
        [-0.49985904, -0.88964033, -0.30124515, ...,  0.15846594,
          0.05210422,  0.25386307]]], dtype=float32)

TF 1.x

如果使用TF 1.x,你应该运行在tf.Session里面操作。 TensorFlow 不使用急切执行,需要先构建图形,然后在会话中评估结果。

elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=True)
x = ["Hi my friend"]
embeddings_op = elmo_model(x, signature="default", as_dict=True)["elmo"]
# required to load the weights into the graph
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    embeddings = sess.run(embeddings_op)

在那种情况下,结果将已经是一个 numpy 数组:

>>> embeddings
array([[[-0.72051036, -0.27990723, -0.773563  , ..., -0.24703972,
         -0.83581805, -0.19747877],
        [ 0.18500218, -0.12270836, -0.35163072, ...,  0.14234722,
          0.08479934, -0.11709933],
        [-0.49985906, -0.8896401 , -0.3012453 , ...,  0.15846589,
          0.05210405,  0.2538631 ]]], dtype=float32)