tensorflow 中 pytorch module.register_parameter(name, param) 的替代方案是什么?

What is the alternative of pytorch module.register_parameter(name, param) in tensorflow?

我正在尝试将一些 pytorch 代码转换为 tensorflow。在 pytorch 代码中,他们使用 module.register_parameter(name, param) 在模型的每个模块上添加了一些额外的参数。我如何在tensorflow上隐藏这部分代码?

下面的示例代码:

for module_name, module in self.model.named_modules():
    module.register_parameter(name, new_parameter)
     

tf.Variable 相当于 PyTorch 中的 nn.Parametertf.Variable主要用于存储模型参数,因为它们的值在训练过程中不断更新。

要将张量用作新模型参数,您需要将其转换为tf.Variable。您可以查看 here 如何从张量创建变量。

如果您想在模型本身内部的 TensorFlow 中添加模型参数,您只需在模型内部创建一个变量 class,它会被 TensorFlow 自动注册为模型参数。

如果要将 tf.Variable 外部 添加到模型作为模型参数,您可以手动将其添加到 trainable_weights attribute of tf.keras.layers.Layer by extending it like this -

model.layers[-1].trainable_weights.extend([new_parameter])