Error while loading wiki40b embeddings from TensorFlow Hub

I'm trying to load this module (https://tfhub.dev/google/wiki40b-lm-nl/1) with hub.KerasLayer, but I'm not sure why I'm getting this error.

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text

hub_url = "https://tfhub.dev/google/wiki40b-lm-nl/1"
embed = hub.KerasLayer(hub_url, input_shape=[],
                       dtype=tf.string)

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-37-4e8ab0d5082c> in <module>()
      5 hub_url = "https://tfhub.dev/google/wiki40b-lm-nl/1"
      6 embed = hub.KerasLayer(hub_url, input_shape=[], 
----> 7                        dtype=tf.string)

1 frames
/usr/local/lib/python3.7/dist-packages/tensorflow_hub/keras_layer.py in _get_callable(self)
    300     if self._signature not in self._func.signatures:
    301       raise ValueError("Unknown signature %s in %s (available signatures: %s)."
--> 302                        % (self._signature, self._handle, self._func.signatures))
    303     f = self._func.signatures[self._signature]
    304     if not callable(f):

ValueError: Unknown signature default in https://tfhub.dev/google/wiki40b-lm-nl/1 (available signatures: _SignatureMap({'neg_log_likelihood': <ConcreteFunction pruned(text) at 0x7F3044A93210>, 'tokenization': <ConcreteFunction pruned(text) at 0x7F3040B7D190>, 'token_neg_log_likelihood': <ConcreteFunction pruned(token) at 0x7F3040D14810>, 'word_embeddings': <ConcreteFunction pruned(text) at 0x7F303D3FF2D0>, 'activations': <ConcreteFunction pruned(text) at 0x7F303D3FFF50>, 'prediction': <ConcreteFunction pruned(mem_4, mem_5, mem_6, mem_7, mem_8, mem_9, mem_10, mem_11, input_tokens, mem_0, mem_1, mem_2, mem_3) at 0x7F303C189090>, 'detokenization': <ConcreteFunction pruned(token_ids) at 0x7F3039860790>, 'token_word_embeddings': <ConcreteFunction pruned(token) at 0x7F3038FC2110>, 'token_activations': <ConcreteFunction pruned(token) at 0x7F303BAF9150>})).

I tried setting signature="word_embeddings" and signature_outputs_as_dict=True, but it turns out the embedding layer does not accept a string as input, only a tensor:

TypeError                                 Traceback (most recent call last)
<ipython-input-36-e98cfe451175> in <module>()
----> 1 embed('ik')

5 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _call_with_flat_signature(self, args, kwargs, cancellation_manager)
   1733         raise TypeError("{}: expected argument #{}(zero-based) to be a Tensor; "
   1734                         "got {} ({})".format(self._flat_signature_summary(), i,
-> 1735                                              type(arg).__name__, str(arg)))
   1736     return self._call_flat(args, self.captured_inputs, cancellation_manager)
   1737 

TypeError: pruned(text): expected argument #0(zero-based) to be a Tensor; got str (ik)
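The TypeError says that the concrete function behind the signature, pruned(text), expects a Tensor as argument #0, so a plain Python str is rejected. A minimal sketch of the conversion, using only tf.constant:

```python
import tensorflow as tf

# tf.constant turns a Python str into a scalar Tensor of dtype tf.string,
# which is the kind of value a signature like pruned(text) takes.
text = tf.constant("ik")
print(text.dtype, text.shape)
```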

My question is: how can I use the embeddings with a str as input, as the module page indicates (in the Inputs section)?

Passing the text wrapped in tf.constant to embed() and setting the output_key keyword argument should make it work:

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text

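# The module has no "default" signature, so select one explicitly;
# output_key picks the "word_embeddings" item out of the signature's output dict.
embed = hub.KerasLayer("https://tfhub.dev/google/wiki40b-lm-nl/1",
                       signature="word_embeddings",
                       output_key="word_embeddings")
# The signature expects a tf.string Tensor, not a Python str.
embed(tf.constant("\n_START_ARTICLE_\n1001 vrouwen uit de Nederlandse "
                  "geschiedenis\n_START_SECTION_\nSelectie van vrouwen"
                  "\n_START_PARAGRAPH_\nDe 'oudste' biografie in het boek "
                  "is gewijd aan de beschermheilige"))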

(Tested with TF 2.4.1 and tensorflow_hub 0.11.0.)
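The input string in the answer follows a wiki40b-style layout, with the _START_ARTICLE_, _START_SECTION_ and _START_PARAGRAPH_ markers each on their own line. A hypothetical helper, format_wiki40b_text, that assembles such a string from one title, one section heading and one paragraph might look like this; the one-section, one-paragraph shape is an assumption taken from the answer's example, not from the module documentation:

```python
def format_wiki40b_text(title: str, section: str, paragraph: str) -> str:
    """Hypothetical helper: build a wiki40b-style input string with one
    article title, one section heading and one paragraph, mirroring the
    layout of the string used in the answer."""
    return (
        "\n_START_ARTICLE_\n" + title
        + "\n_START_SECTION_\n" + section
        + "\n_START_PARAGRAPH_\n" + paragraph
    )


text = format_wiki40b_text(
    "1001 vrouwen uit de Nederlandse geschiedenis",
    "Selectie van vrouwen",
    "De 'oudste' biografie in het boek is gewijd aan de beschermheilige",
)
```

The resulting string is identical to the one in the answer, so it can be wrapped in tf.constant and passed to embed() in the same way.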