How to remove training=True from the inbound nodes of a layer in an existing model?
Assume a model is given as an h5 file, i.e., I cannot change the code that builds the model architecture:
from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model
inputs = Input(shape=(4,))
outputs = BatchNormalization()(inputs, training=True)
model = Model(inputs=inputs, outputs=outputs)
model.save('model.h5', include_optimizer=False)
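Now I would like to remove the training=True part, i.e., I want the BatchNormalization layer to behave as if it had been attached to the model without this flag.
My current attempt looks as follows: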
import numpy as np
from tensorflow.keras.models import load_model
model = load_model('model.h5')
for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]
model.predict(np.asarray([[1, 2, 3, 4]]))
But the model.predict call fails with the following error (I am using TensorFlow 2.5.0):
ValueError: Could not pack sequence. Structure had 1 elements, but flat_sequence had 2 elements. Structure: ((<KerasTensor: shape=(None, 4) dtype=float32 (created by layer 'input_1')>,), {}), flat_sequence: [<tf.Tensor 'model/Cast:0' shape=(None, 4) dtype=float32>, True].
How can this be fixed or worked around?
(When using node.call_kwargs["training"] = False instead of del node.call_kwargs["training"], model.predict does not crash, but it behaves as if nothing had changed, i.e., the modified flag is ignored.)
Have you tried the following?
for layer in model.layers:
    layer.trainable = False
I found that simply saving and reloading the model after modifying call_kwargs helps.
import numpy as np
from tensorflow.keras.models import load_model
model = load_model('model.h5')
# Removing training=True
for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]
# The two following lines are the solution.
model.save('model_modified.h5', include_optimizer=False)
model = load_model('model_modified.h5')
model.predict(np.asarray([[1, 2, 3, 4]]))
Everything works fine. :)
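One way to tell whether the flag is really gone is to compare the prediction with what each mode should produce. The following is a minimal pure-Python sketch of the batch normalization arithmetic, not Keras code. It assumes the layer is untrained and uses the Keras defaults (epsilon of 1e-3, gamma 1, beta 0, moving mean 0, moving variance 1); the helper name batch_norm is made up for illustration.

```python
import math

EPSILON = 1e-3  # assumed: default epsilon of Keras BatchNormalization


def batch_norm(batch, training):
    """Emulate an untrained batch norm layer (gamma=1, beta=0) on a list of rows."""
    n_rows = len(batch)
    n_features = len(batch[0])
    if training:
        # Training mode: normalize with the statistics of the current batch.
        means = [sum(row[j] for row in batch) / n_rows for j in range(n_features)]
        variances = [
            sum((row[j] - means[j]) ** 2 for row in batch) / n_rows
            for j in range(n_features)
        ]
    else:
        # Inference mode: normalize with the stored moving statistics,
        # assumed here to still hold their initial values (mean 0, variance 1).
        means = [0.0] * n_features
        variances = [1.0] * n_features
    return [
        [(row[j] - means[j]) / math.sqrt(variances[j] + EPSILON)
         for j in range(n_features)]
        for row in batch
    ]


x = [[1.0, 2.0, 3.0, 4.0]]
train_out = batch_norm(x, training=True)   # single-sample batch: every value equals its mean
infer_out = batch_norm(x, training=False)  # input scaled by 1 / sqrt(1 + EPSILON)
print(train_out)
print(infer_out)
```

Under these assumptions, a model that still runs with training=True should predict all zeros for the single-row input above, while the reloaded model should return values very close to the input itself.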