如何在 TF 2.6.0 / Python 3.9.7 中保存和重新加载子类模型而不会导致性能下降?

How to save and reload a Subclassed model in TF 2.6.0 / Python 3.9.7 wihtout performance drop?

看起来像百万美元的问题。我在 Keras 中通过子类 Model 构建了下面的模型。

模型训练良好且性能良好,但我无法找到一种方法来保存和恢复模型而不会导致显着的性能损失。 我在ROC曲线上跟踪AUC用于异常检测,加载模型后的ROC曲线比之前差,使用完全相同的验证数据集。

我怀疑问题出在 BatchNormalization 上,但我可能错了。

我试过几个选项:

这有效但会导致性能下降。

model.save() / tf.keras.models.load()

这有效但也会导致性能下降:

model.save_weights() / model.load_weights()

这不起作用,我收到以下错误:

tf.saved_model.save() / tf.saved_model.load()

AttributeError: '_UserObject' object has no attribute 'predict'

这也不起作用,因为子类模型不支持 json 导出:

model.to_json()

这是模型:

class Deep_Seq2Seq_Detector(Model):
  def __init__(self, flight_len, param_len, hidden_state=16):
    super(Deep_Seq2Seq_Detector, self).__init__()
    self.input_dim = (None, flight_len, param_len)
    self._name_ = "LSTM"
    self.units = hidden_state
    
    self.regularizer0 = tf.keras.Sequential([
        layers.BatchNormalization()
        ])
    
    self.encoder1 = layers.LSTM(self.units,
                  return_state=False,
                  return_sequences=True,
                  #activation="tanh",
                  name='encoder1',
                  input_shape=self.input_dim)#,
                  #kernel_regularizer= tf.keras.regularizers.l1(),
                  #)
    
    self.regularizer1 = tf.keras.Sequential([
        layers.BatchNormalization(),
        layers.Activation("tanh")
        ])
    
    self.encoder2 = layers.LSTM(self.units,
                  return_state=False,
                  return_sequences=True,
                  #activation="tanh",
                  name='encoder2')#,
                  #kernel_regularizer= tf.keras.regularizers.l1()
                  #) #                    input_shape=(None, self.input_dim[1],self.units),
    
    self.regularizer2 = tf.keras.Sequential([
        layers.BatchNormalization(),
        layers.Activation("tanh")
        ])
    
    self.encoder3 = layers.LSTM(self.units,
                  return_state=True,
                  return_sequences=False,
                  activation="tanh",
                  name='encoder3')#,
                  #kernel_regularizer= tf.keras.regularizers.l1(),
                  #) #                   input_shape=(None, self.input_dim[1],self.units),
    
    self.repeat = layers.RepeatVector(self.input_dim[1])
    
    self.decoder = layers.LSTM(self.units,
                  return_sequences=True,
                  activation="tanh",
                  name="decoder",
                  input_shape=(self.input_dim[1],self.units))
    
    self.dense = layers.TimeDistributed(layers.Dense(self.input_dim[2]))

  @tf.function 
  def call(self, x):
    
    # Encoder
    x0 = self.regularizer0(x)
    x1 = self.encoder1(x0)
    x11 = self.regularizer1(x1)
    
    x2 = self.encoder2(x11)
    x22 = self.regularizer2(x2)
    
    output, hs, cs = self.encoder3(x22)
    
    # see https://www.tensorflow.org/guide/keras/rnn 
    encoded_state = [hs, cs] 
    repeated_vec = self.repeat(output)
    
    # Decoder
    decoded = self.decoder(repeated_vec, initial_state=encoded_state)
    output_decoder = self.dense(decoded)

    return output_decoder

我看过 Git 个话题,但没有直接的答案: https://github.com/keras-team/keras/issues/4875

有人找到解决办法了吗?我必须改用函数式或顺序式 API 吗?

问题似乎出在 Sublcassing API

我使用 Functionnal API 重建了完全相同的模型,现在 model.save / model.load 产生了相似的结果。