如何在 TF2.0 的 `tf.Module` 中初始化变量

How to initialize variables in a `tf.Module` in TF2.0

在Tensorflow2.0中,我发现我可以通过以下方式初始化模型中的变量

class MyModel(tf.keras.Model):
    def __init__(self, *args, kwargs**):
        """ some definition here """
        self(tf.keras.Input(shape=(3,)))

    def call(self, x):
        """ some implementation """        

但是我做不到

class MyModel(tf.keras.Model):
    def __init__(self, *args, kwargs**):
        """ some definition here """
        self.step(tf.keras.Input(shape=(3,)))

    def step(self, x):
        """ some implementation """        

这会报错 我想做第二个的原因是我试图从 tf.Module 继承 MyModel,它没有可用的 __call__ --- 即使我定义了一个,也会出现同样的错误。我想知道是否有一种方法可以像我在第一个代码块中那样初始化从 tf.Module 继承的 class 中的变量?

Keras functional/symbolic API 遗憾的是只兼容 Keras(例如 compile+fit)。

您可以改用 tf.zeros(例如 self(tf.zeros(input_shape))),尽管这可能会产生不良副作用(例如影响您的批量规范统计数据)。

如果你想要一个强大的解决方案,你可能需要考虑使用 snt.build(self, input_shape) [0],它是 Sonnet 2 中的一个实用函数(一个包含一堆常见 tf.Module 的库).

[0] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/build.py#L50