如何在 TF2.0 的 `tf.Module` 中初始化变量
How to initialize variables in a `tf.Module` in TF2.0
在Tensorflow2.0中,我发现我可以通过以下方式初始化模型中的变量
class MyModel(tf.keras.Model):
def __init__(self, *args, kwargs**):
""" some definition here """
self(tf.keras.Input(shape=(3,)))
def call(self, x):
""" some implementation """
但是我做不到
class MyModel(tf.keras.Model):
def __init__(self, *args, kwargs**):
""" some definition here """
self.step(tf.keras.Input(shape=(3,)))
def step(self, x):
""" some implementation """
这会报错
我想做第二个的原因是我试图从 tf.Module
继承 MyModel
,它没有可用的 __call__
--- 即使我定义了一个,也会出现同样的错误。我想知道是否有一种方法可以像我在第一个代码块中那样初始化从 tf.Module
继承的 class 中的变量?
Keras functional/symbolic API 遗憾的是只兼容 Keras(例如 compile+fit)。
您可以改用 tf.zeros
(例如 self(tf.zeros(input_shape))
),尽管这可能会产生不良副作用(例如影响您的批量规范统计数据)。
如果你想要一个强大的解决方案,你可能需要考虑使用 snt.build(self, input_shape)
[0],它是 Sonnet 2 中的一个实用函数(一个包含一堆常见 tf.Module
的库).
[0] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/build.py#L50
在Tensorflow2.0中,我发现我可以通过以下方式初始化模型中的变量
class MyModel(tf.keras.Model):
def __init__(self, *args, kwargs**):
""" some definition here """
self(tf.keras.Input(shape=(3,)))
def call(self, x):
""" some implementation """
但是我做不到
class MyModel(tf.keras.Model):
def __init__(self, *args, kwargs**):
""" some definition here """
self.step(tf.keras.Input(shape=(3,)))
def step(self, x):
""" some implementation """
这会报错
tf.Module
继承 MyModel
,它没有可用的 __call__
--- 即使我定义了一个,也会出现同样的错误。我想知道是否有一种方法可以像我在第一个代码块中那样初始化从 tf.Module
继承的 class 中的变量?
Keras functional/symbolic API 遗憾的是只兼容 Keras(例如 compile+fit)。
您可以改用 tf.zeros
(例如 self(tf.zeros(input_shape))
),尽管这可能会产生不良副作用(例如影响您的批量规范统计数据)。
如果你想要一个强大的解决方案,你可能需要考虑使用 snt.build(self, input_shape)
[0],它是 Sonnet 2 中的一个实用函数(一个包含一堆常见 tf.Module
的库).
[0] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/build.py#L50