如何在 Tensorflow 2.0 中复制网络

How to copy a network in Tensorflow 2.0

我不确定如何在 Tensorflow 2.0 中通过网络进行复制。在 Tensorflow 1.x 中有很多关于如何做到这一点的答案,但 none 关于 2.0。这两个网络都是通过 tf.keras.Model 的子类创建的,所以我不能使用 tf.keras.models.clone_model 函数。

我尝试了下面列出的不同方法,但 none 似乎有效。

network1 = network2
network1.weights = network2.weights

from copy import copy
network1 = copy(network2)

其中一些方法会引用当前网络,但不会实际复制它。非常感谢我能得到的所有帮助!

假设model_amodel_b是同一个Keras模型的实例化。然后做:

for a, b in zip(model_a.variables, model_b.variables):
  a.assign(b)  # copies the variables of model_b into model_a