Saving Keras model with multiple outputs (The wrapped dictionary was modified outside the wrapper)

I have a multi-class classifier, and I also want to query the output of one of its intermediate layers.

inputs = Input(...)

...

fc = Dense(32, activation='relu', name='FC_1')(layer)
x = Dense(num_cats, activation='softmax', name='Softmax')(fc)

outputs = {
    'predictions': x,
    'fc': fc,
}

model = Model(inputs=inputs, outputs=outputs)

opt = Adam()
cat_crossentropy = SparseCategoricalCrossentropy()


loss = {
    'predictions': cat_crossentropy,
}


model.compile(optimizer=opt, loss=loss, metrics={'predictions': [
    SparseCategoricalAccuracy(),
    SparseTopKCategoricalAccuracy(k=3),
]})

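This works fine within the same process: model.fit() does what you would expect, and model.predict() returns a dictionary with values for the two keys defined in outputs.

However, model.save(output_path) raises the following exception: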

ValueError: Unable to save the object {
   'predictions': <keras.losses.SparseCategoricalCrossentropy object at 0x15e5c6c70>,
   'dense': None,
}

(a dictionary wrapper constructed automatically on attribute assignment).

The wrapped dictionary was modified outside the wrapper (its final value was

{
   'predictions': <keras.losses.SparseCategoricalCrossentropy object at 0x15e5c6c70>,
   'dense': None
},

its value when a checkpoint dependency was added was None),
which breaks restoration on object creation.

If you don't need this dictionary checkpointed,
wrap it in a non-trackable object; it will be subsequently ignored.

Question

How can I export a SavedModel that includes named outputs?

Edit - reproducible code

import tensorflow as tf
import numpy as np

from tensorflow.data import Dataset
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.losses import SparseCategoricalCrossentropy


def build_model():
    i = Input(shape=(5,))
    fc = Dense(5, activation='relu')(i)
    softmax = Dense(5, activation='softmax', name='Softmax')(fc)

    outputs = {'predictions': softmax, 'dense': fc}
    return Model(inputs=i, outputs=outputs)


model = build_model()
opt = tf.keras.optimizers.Adam(learning_rate=0.003)

model.compile(optimizer=opt, metrics={'predictions': [SparseCategoricalAccuracy()]}, loss={
    'predictions': SparseCategoricalCrossentropy(from_logits=False)
})

_ds = Dataset.from_tensor_slices((np.random.rand(10, 5), np.random.randint(0, 5, 10)))
model.fit(_ds.batch(32), epochs=3, verbose=0)

print(model.predict([[0.1] * 5]))
# >>> {
#       'predictions': array([[
#                               0.20790772,
#                               0.18233214,
#                               0.18703228,
#                               0.20266992,
#                               0.22005796
#                      ]], dtype=float32),
#       'dense': array([[
#                         0.1394624,
#                         0.16406271,
#                         0.,
#                         0.,
#                         0.
#                      ]], dtype=float32)}

model.save('my_model')
# >>> Traceback (most recent call last): ...

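This worked for me; let me know if it works for you too. It passes the outputs as a list instead of a dictionary, which helped me figure out what was going on.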

import tensorflow as tf
import numpy as np

from tensorflow.data import Dataset
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Dense
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.losses import SparseCategoricalCrossentropy


def build_model():
    i = Input(shape=(5,))
    fc = Dense(5, activation='relu')(i)
    softmax = Dense(5, activation='softmax', name='Softmax')(fc)

    # outputs = {'predictions': softmax, 'dense': fc}
    left_output = softmax
    right_output = fc
    return Model(inputs=i, outputs=[left_output, right_output])


model = build_model()
opt = tf.keras.optimizers.Adam(learning_rate=0.003)

# model.compile(optimizer=opt, metrics={'predictions': [SparseCategoricalAccuracy()]}, loss={
#     'predictions': SparseCategoricalCrossentropy(from_logits=False)
# })
model.compile(optimizer=opt, metrics="SparseCategoricalAccuracy", loss=[SparseCategoricalCrossentropy(from_logits=False)])

_ds = Dataset.from_tensor_slices((np.random.rand(10, 5), np.random.randint(0, 5, 10)))
model.fit(_ds.batch(32), epochs=3, verbose=1)

print(type(model.predict([[0.1] * 5])))
print(model.predict([[0.1] * 5]))
# >>> a list of two arrays instead of a dictionary (see the output below)

model.save('my_model')
# >>> saves without raising (see the output below)

Output

Epoch 1/3
1/1 [==============================] - 0s 207ms/step - loss: 1.7489 - Softmax_loss: 1.7489 - Softmax_sparse_categorical_accuracy: 0.3000 - dense_4_sparse_categorical_accuracy: 0.4000
Epoch 2/3
1/1 [==============================] - 0s 1ms/step - loss: 1.7379 - Softmax_loss: 1.7379 - Softmax_sparse_categorical_accuracy: 0.3000 - dense_4_sparse_categorical_accuracy: 0.4000
Epoch 3/3
1/1 [==============================] - 0s 920us/step - loss: 1.7271 - Softmax_loss: 1.7271 - Softmax_sparse_categorical_accuracy: 0.3000 - dense_4_sparse_categorical_accuracy: 0.4000
WARNING:tensorflow:5 out of the last 7 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f229d14df70> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
<class 'list'>
[array([[0.1952633 , 0.19603369, 0.18926372, 0.2137877 , 0.20565161]],
      dtype=float32), array([[0.12314494, 0.        , 0.13512844, 0.        , 0.        ]],
      dtype=float32)]
INFO:tensorflow:Assets written to: my_model/assets
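
With list outputs the names are gone, so the caller has to remember the order. If you go this route, a small helper can put the names back on the client side. This is only a sketch: name_outputs is a hypothetical helper (not part of Keras), and the nested lists below stand in for the NumPy arrays that model.predict() returns.

```python
from typing import Dict, Sequence


def name_outputs(names: Sequence[str], outputs: Sequence) -> Dict[str, object]:
    """Pair each positional output with a name, in the same order."""
    if len(names) != len(outputs):
        raise ValueError(f"expected {len(names)} outputs, got {len(outputs)}")
    return dict(zip(names, outputs))


# Stand-in for model.predict(...) on the two-output model above;
# the values are made up for illustration.
raw_outputs = [[[0.2, 0.2, 0.2, 0.2, 0.2]], [[0.1, 0.0, 0.1, 0.0, 0.0]]]

# The order must match the order of the list passed to Model(outputs=[...]).
named = name_outputs(["predictions", "dense"], raw_outputs)
print(list(named))
```

The names only exist in your own code here; the saved model itself still exposes positional outputs.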

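I found a solution, and a gotcha.

Solution

The solution is simple: for outputs that do not need a loss (the intermediate layers), do not omit their keys from the loss dictionary passed to model.compile. Instead, add them with a value of None.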

model.compile(
    optimizer=opt,
    metrics={'predictions': [
        SparseCategoricalAccuracy(),
        SparseTopKCategoricalAccuracy(k=3)
    ]},
    loss={
        'predictions': SparseCategoricalCrossentropy(from_logits=False),
        'fc': None,
    #   ^^^^^^^^^^^
    #   The solution
    })

Now the output of model.predict() is a dictionary with the same keys as in model.outputs, and loading the model with tf.keras.models.load_model() works fine.

Gotcha
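
If the model has many outputs, it is easy to forget one of the None entries. A rough sketch of building the loss dictionary from the output keys, so that every key is always present: complete_loss_dict is a hypothetical helper (not a Keras function), and the string loss identifier stands in for the SparseCategoricalCrossentropy object used above.

```python
from typing import Dict, Iterable, Optional


def complete_loss_dict(output_keys: Iterable[str],
                       losses: Dict[str, object]) -> Dict[str, Optional[object]]:
    """Return a loss dict with one entry per output key; missing keys map to None."""
    keys = list(output_keys)
    unknown = sorted(set(losses) - set(keys))
    if unknown:
        # A loss for a key that is not an output is most likely a typo.
        raise KeyError(f"losses given for unknown outputs: {unknown}")
    return {key: losses.get(key) for key in keys}


loss = complete_loss_dict(
    ["predictions", "fc"],
    {"predictions": "sparse_categorical_crossentropy"},
)
print(loss)  # {'predictions': 'sparse_categorical_crossentropy', 'fc': None}
```

The resulting dictionary would then be passed as loss=loss to model.compile.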

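In TensorFlow Serving, when you run inference, the output will be a dictionary whose keys are the Layer.name of the layers referenced in model.outputs, not the keys you used in the outputs dictionary.
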
# Layer names: ("FC_1", "Softmax")
fc = Dense(32, activation='relu', name='FC_1')(layer)
x = Dense(num_cats, activation='softmax', name='Softmax')(fc)

# Keys in model.outputs: ("predictions", "fc")
outputs = {
    'predictions': x,
    'fc': fc,
}

model = Model(inputs=inputs, outputs=outputs)

When you evaluate the model through TensorFlow Serving, the output takes the following form:

{
    "outputs": {
        "Softmax": [ [ 0.1, ... ] ],
        "FC_1": [ [...] ]
    }
}
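
If the client code expects the original keys, one option is to map the layer names back after the call. A minimal sketch, assuming the response body has the shape shown above; LAYER_TO_KEY and rename_serving_outputs are hypothetical names, and the sample values are made up.

```python
import json
from typing import Dict

# Layer.name -> key used in the Keras outputs dictionary (from the model above).
LAYER_TO_KEY = {"Softmax": "predictions", "FC_1": "fc"}


def rename_serving_outputs(response_body: str,
                           layer_to_key: Dict[str, str]) -> Dict[str, object]:
    """Rename the keys under "outputs"; names without a mapping are kept as-is."""
    outputs = json.loads(response_body)["outputs"]
    return {layer_to_key.get(name, name): value for name, value in outputs.items()}


# Hand-written response body in the shape shown above.
sample = '{"outputs": {"Softmax": [[0.1, 0.9]], "FC_1": [[0.0, 1.5]]}}'
renamed = rename_serving_outputs(sample, LAYER_TO_KEY)
print(renamed)
```

Alternatively, naming the layers the same as the dictionary keys would sidestep the mismatch entirely.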

You can also use the /v1/models/{modelId}/versions/1/metadata REST endpoint to inspect the model's output layers in TensorFlow Serving.
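
A rough sketch of reading the output names from that endpoint. The base URL, the model name, and the "serving_default" signature name are assumptions, and SAMPLE_METADATA is a hand-written, trimmed stand-in for the real response, so check the layout against what your own server returns.

```python
import json
from typing import List
from urllib.request import urlopen


def metadata_url(base_url: str, model_id: str, version: int) -> str:
    """Build the metadata URL for one model version."""
    return f"{base_url}/v1/models/{model_id}/versions/{version}/metadata"


def output_names(metadata: dict, signature: str = "serving_default") -> List[str]:
    """Return the sorted output names of one signature in a metadata response."""
    signatures = metadata["metadata"]["signature_def"]["signature_def"]
    return sorted(signatures[signature]["outputs"])


def fetch_output_names(base_url: str, model_id: str, version: int) -> List[str]:
    """Query a running server (not called here, since it needs a live endpoint)."""
    with urlopen(metadata_url(base_url, model_id, version)) as response:
        return output_names(json.load(response))


# Hand-written, trimmed sample; a real response carries more fields.
SAMPLE_METADATA = {
    "metadata": {
        "signature_def": {
            "signature_def": {
                "serving_default": {"outputs": {"Softmax": {}, "FC_1": {}}}
            }
        }
    }
}

print(metadata_url("http://localhost:8501", "my_model", 1))
print(output_names(SAMPLE_METADATA))
```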