copy/clone 个具有自定义属性的 keras 子类模型

copy/clone of keras subclassed models with custom attributes

我有一个带有一些自定义属性的子类模型,如下所示:

class MyModel(tf.keras.Model):
    def __init__(self, *args, my_var, **kwargs):
        super().__init__(*args, **kwargs)
        self.my_var = my_var

    def my_func(self):
        pass

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "my_var": self.my_var
            }
        )
        return config

现在我定义模型并使用 clone_model

克隆它
x_in = layers.Input(shape=(100, 100, 3))
x_out = layers.Conv2D(filters=16, kernel_size=3, activation="relu")(x_in)

model = MyModel(inputs=x_in, outputs=x_out, my_var="my_var")

cloned = tf.keras.models.clone_model(model)
print(cloned.my_var)

模型复制成功,但没有my_var

有什么方法可以正确复制具有所有属性(my_var 和 my_func)的此类模型?

您需要添加

cloned = model.__class__.from_config(model.get_config())

如文档中所示 https://www.tensorflow.org/api_docs/python/tf/keras/models/clone_model#example