从 tensorflow hub 加载模型时缺少可训练参数
Missing trainable parameter when loading model from tensorflow hub
我正在将我们的代码从 tensorflow 1 迁移到 tensorflow 2。其中一层是嵌入层加载如下:
import tensorflow_hub as hub
model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/1"
self.use_embed = hub.Module(model_url, trainable=False)
在 Tensorflow 2 中,这将变为
import tensorflow_hub as hub
model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
self.use_embed = hub.load(model_url)
The hub.Module API works for TF1 only. For TF2, switch to plain SavedModels and hub.load().
但是load()
方法不支持trainable
参数?
这个参数发生了什么变化,我如何在 Tensorflow 2 中应用它?
Model Compatibility Guide 提到参数对于 hub.load()
和 hub.KerasLayer()
有不同的名称:
Use either hub.load:
m = hub.load(handle)
outputs = m(inputs, training=is_training)
or hub.KerasLayer:
m = hub.KerasLayer(handle, trainable=True)
outputs = m(inputs)
我正在将我们的代码从 tensorflow 1 迁移到 tensorflow 2。其中一层是嵌入层加载如下:
import tensorflow_hub as hub
model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/1"
self.use_embed = hub.Module(model_url, trainable=False)
在 Tensorflow 2 中,这将变为
import tensorflow_hub as hub
model_url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
self.use_embed = hub.load(model_url)
The hub.Module API works for TF1 only. For TF2, switch to plain SavedModels and hub.load().
但是load()
方法不支持trainable
参数?
这个参数发生了什么变化,我如何在 Tensorflow 2 中应用它?
Model Compatibility Guide 提到参数对于 hub.load()
和 hub.KerasLayer()
有不同的名称:
Use either hub.load:
m = hub.load(handle)
outputs = m(inputs, training=is_training)
or hub.KerasLayer:
m = hub.KerasLayer(handle, trainable=True)
outputs = m(inputs)