TensorFlow 2.0 和 TensorFlow Hub:load_module_spec 等效?

TensorFlow 2.0 & TensorFlow Hub: load_module_spec equivalent?

使用 TensorFlow 1.x 和 TensorFlow hub 时,我们可以加载模块的规范来检查预期的输出形状(可能还有其他有用的规范!),如下所示:

spec = hub.load_module_spec("https://tfhub.dev/google/nnlm-en-dim128/1")
shape = spec.get_output_info_dict()['default'].get_shape()

当尝试对 TF 2.0 兼容集线器模块执行相同操作时,我在调用 load_module_spec 时遇到以下错误消息:

Missing implementation that supports: loader(*('/tmp/tfhub_modules/82c4aaf4250ffb09088bd48368ee7fd00e5464fe',), **{})

是否有检查 TF 2.0 集线器模块输出形状的替代方法?

对于 TensorFlow 2,TF Hub 将转为提供 TF2 的原生 object-based SavedModels [doc, RFC]。如果您的文件系统上已经存在,则它们由 tf.saved_model.load() 加载,或者 hub.load() 可选择从 URL 下载。这为您提供了一个恢复的 Trackable 对象,该对象具有一个 __call__ 成员,其行为类似于 @tf.function,这意味着它具有一个或多个具体函数,每个函数都由一个 TF 图支持,并在它们之间基于在 Tensor shapes/dtypes 和 non-Tensor 参数上。

对于当前的 TF2 alpha 版本,如果您知道允许的输入 TensorSpec,您可以深入了解输出,例如:

loaded_model = hub.load("https://tfhub.dev/google/tf2-preview/nnlm-en-dim128/1")
concrete_function = loaded_model.__call__.get_concrete_function(
    tf.TensorSpec((None,), tf.string))
print(concrete_function.output_shapes, ":",
      concrete_function.output_dtypes)

这给了我

(None, 128) : <dtype: 'float32'>