如何从现有模型的层的入站节点中删除 training=True?

How to remove training=True from the inbound nodes of a layer in an existing model?

假设有一个模型作为 h5 文件给出,即我无法更改构建模型架构的代码:

from tensorflow.keras.layers import Input, BatchNormalization
from tensorflow.keras.models import Model

inputs = Input(shape=(4,))
outputs = BatchNormalization()(inputs, training=True)
model = Model(inputs=inputs, outputs=outputs)
model.save('model.h5', include_optimizer=False)

现在我想删除 training=True 部分,即,我希望 BatchNormalization 就像它附加到没有此标志的模型一样。

我目前的尝试如下:

import numpy as np
from tensorflow.keras.models import load_model

model = load_model('model.h5')

for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]

model.predict(np.asarray([[1, 2, 3, 4]]))

但是 model.predict 调用失败并出现以下错误(我使用的是 TensorFlow 2.5.0):

ValueError: Could not pack sequence. Structure had 1 elements, but flat_sequence had 2 elements.  Structure: ((<KerasTensor: shape=(None, 4) dtype=float32 (created by layer 'input_1')>,), {}), flat_sequence: [<tf.Tensor 'model/Cast:0' shape=(None, 4) dtype=float32>, True].

怎么会这样fixed/worked?

(当使用 node.call_kwargs["training"] = False 而不是 del node.call_kwargs["training"] 时,model.predict 不会崩溃,但它的行为就像什么都没有改变一样,即忽略修改后的标志。)

你试过了吗

for layer in model.layers:
    layer.trainable=False

我发现,只需在修改后再次保存并重新加载模型 call_kwargs 有帮助。

import numpy as np
from tensorflow.keras.models import load_model

model = load_model('model.h5')

# Removing training=True
for layer in model.layers:
    for node in layer.inbound_nodes:
        if "training" in node.call_kwargs:
            del node.call_kwargs["training"]

# The two following lines are the solution.
model.save('model_modified.h5', include_optimizer=False)
model = load_model('model_modified.h5')

model.predict(np.asarray([[1, 2, 3, 4]]))

一切都很好。 :)