如何在tensorflow中设置keras子类模型的输入?
How to set the input of a keras subclass model in tensorflow?
我使用 tensorflow 创建了一个 keras 子类模型。片段如下所示。
class SubModel(Model):
def call(self, inputs):
print(inputs)
model = SubModel()
model.fit(data, labels, ...)
当 fit
模型时,它将获得输入和 input_shape 本身。我想要做的是将输入传递给模型 myself.Just,就像函数 API 所做的那样。
inputs = tf.keras.input(shape=(100,))
model = tf.keras.Model(inputs=inputs, outputs=outputs)
类似的东西?
model_ = SubModel()
inputs = tf.keras.input(shape=(100,))
outputs = model_(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
我最终放弃了 keras.Model subclassing。这太棘手了,我收到有关输入形状的错误。
我希望能够直接在我的自定义 class 模型对象上使用 .fit()
。
为此,我发现一个简单的方法是实现内置的 __getattr__
方法(更多信息可以在 official Python doc 中找到)。我使用的 class 实现:
from tensorflow.keras import Input, layers, Model
class SubModel():
def __init__(self):
self.model = self.get_model()
def get_model(self):
# here we use the usual Keras functional API
x = Input(shape=(24, 24, 3))
y = layers.Conv2D(28, 3, strides=1)(x)
return Model(inputs=[x], outputs=[y])
def __getattr__(self, name):
"""
This method enables to access an attribute/method of self.model.
Thus, any method of keras.Model() can be used transparently from a SubModel object
"""
return getattr(self.model, name)
if __name__ == '__main__':
submodel = SubModel()
submodel.fit(data, labels, ...) # underlyingly calls SubModel.model.fit()
如果你想在调用model.fit
之前能够指定输入的形状,你可以使用model.build
它需要一个位置参数:input_shape
.
不相关(但以防其他人遇到此问题),只要您想调用 model.summary
或有时使用 Dense 层,就需要这样做,例如:
ValueError: The last dimension of the inputs to "Dense" should be defined. Found "None".
举个例子:
class MyModel(keras.Model):
def __init__(self, input_shape):
super().__init__()
# Example layers that would through an error if we didn't call build
self.convT1 = keras.layers.Conv2DTranspose(filters=1, kernel_size=10)
self.dense = keras.layers.Dense(10)
self.compile(optimizer='Adam', loss='mse', metrics='acc'))
# ! Call build and pass the input_shape
self.build(input_shape)
self.summary() # Because we can now! (would fail without self.build)
model = MyModel(input_shape=(1, 1, 10, 10))
您也可以在初始化后调用 model.build
而不是 self.build
。
我使用 tensorflow 创建了一个 keras 子类模型。片段如下所示。
class SubModel(Model):
def call(self, inputs):
print(inputs)
model = SubModel()
model.fit(data, labels, ...)
当 fit
模型时,它将获得输入和 input_shape 本身。我想要做的是将输入传递给模型 myself.Just,就像函数 API 所做的那样。
inputs = tf.keras.input(shape=(100,))
model = tf.keras.Model(inputs=inputs, outputs=outputs)
类似的东西?
model_ = SubModel()
inputs = tf.keras.input(shape=(100,))
outputs = model_(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
我最终放弃了 keras.Model subclassing。这太棘手了,我收到有关输入形状的错误。
我希望能够直接在我的自定义 class 模型对象上使用 .fit()
。
为此,我发现一个简单的方法是实现内置的 __getattr__
方法(更多信息可以在 official Python doc 中找到)。我使用的 class 实现:
from tensorflow.keras import Input, layers, Model
class SubModel():
def __init__(self):
self.model = self.get_model()
def get_model(self):
# here we use the usual Keras functional API
x = Input(shape=(24, 24, 3))
y = layers.Conv2D(28, 3, strides=1)(x)
return Model(inputs=[x], outputs=[y])
def __getattr__(self, name):
"""
This method enables to access an attribute/method of self.model.
Thus, any method of keras.Model() can be used transparently from a SubModel object
"""
return getattr(self.model, name)
if __name__ == '__main__':
submodel = SubModel()
submodel.fit(data, labels, ...) # underlyingly calls SubModel.model.fit()
如果你想在调用model.fit
之前能够指定输入的形状,你可以使用model.build
它需要一个位置参数:input_shape
.
不相关(但以防其他人遇到此问题),只要您想调用 model.summary
或有时使用 Dense 层,就需要这样做,例如:
ValueError: The last dimension of the inputs to "Dense" should be defined. Found "None".
举个例子:
class MyModel(keras.Model):
def __init__(self, input_shape):
super().__init__()
# Example layers that would through an error if we didn't call build
self.convT1 = keras.layers.Conv2DTranspose(filters=1, kernel_size=10)
self.dense = keras.layers.Dense(10)
self.compile(optimizer='Adam', loss='mse', metrics='acc'))
# ! Call build and pass the input_shape
self.build(input_shape)
self.summary() # Because we can now! (would fail without self.build)
model = MyModel(input_shape=(1, 1, 10, 10))
您也可以在初始化后调用 model.build
而不是 self.build
。