从 tensorflow hub 加载模型时缺少可训练参数

Missing trainable parameter when loading model from tensorflow hub

我正在将我们的代码从 tensorflow 1 迁移到 tensorflow 2。其中一层是嵌入层加载如下:

import tensorflow_hub as hub
model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/1"
self.use_embed = hub.Module(model_url, trainable=False)

在 Tensorflow 2 中,这将变为

import tensorflow_hub as hub
model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
self.use_embed = hub.load(model_url)

because

The hub.Module API works for TF1 only. For TF2, switch to plain SavedModels and hub.load().

但是load()方法不支持trainable参数?

这个参数发生了什么变化,我如何在 Tensorflow 2 中应用它?

Model Compatibility Guide 提到参数对于 hub.load()hub.KerasLayer() 有不同的名称:

Use either hub.load:
m = hub.load(handle)
outputs = m(inputs, training=is_training)

or hub.KerasLayer:
m = hub.KerasLayer(handle, trainable=True)
outputs = m(inputs)