无法保存模型架构(bilstm + attention)

Unable to save model architecture (bilstm + attention)

我正在处理多标签文本分类问题。我正在尝试使用 bilstm 模型添加注意力机制。注意力机制代码取自here。我无法保存模型架构并收到下面提到的错误。我的tensorflow版本-2.2.0

from keras import backend as K

def dot_product(x, kernel):
    if K.backend() == 'tensorflow':
        return K.squeeze(K.dot(x, K.expand_dims(kernel)), axis=-1)
    else:
        return K.dot(x, kernel)

class AttentionWithContext(tf.keras.layers.Layer):
    """
    # Input shape
        3D tensor with shape: `(samples, steps, features)`.
    # Output shape
        2D tensor with shape: `(samples, features)`.
    """
    def __init__(self,
                 W_regularizer=None, u_regularizer=None, b_regularizer=None,
                 W_constraint=None, u_constraint=None, b_constraint=None,
                 bias=True, **kwargs):

        self.supports_masking = True
        self.init = tf.keras.initializers.get('glorot_uniform')

        self.W_regularizer = tf.keras.regularizers.get(W_regularizer)
        self.u_regularizer = tf.keras.regularizers.get(u_regularizer)
        self.b_regularizer = tf.keras.regularizers.get(b_regularizer)

        self.W_constraint = tf.keras.constraints.get(W_constraint)
        self.u_constraint = tf.keras.constraints.get(u_constraint)
        self.b_constraint = tf.keras.constraints.get(b_constraint)

        self.bias = bias
        super(AttentionWithContext, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3

        self.W = self.add_weight(shape=(input_shape[-1], input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(input_shape[-1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)

        self.u = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)

        super(AttentionWithContext, self).build(input_shape)

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def call(self, x, mask=None):
        uit = dot_product(x, self.W)

        if self.bias:
            uit += self.b

        uit = K.tanh(uit)
        ait = dot_product(uit, self.u)

        a = K.exp(ait)

        # apply mask after the exp. will be re-normalized next
        if mask is not None:
            # Cast the mask to floatX to avoid float64 upcasting in theano
            a *= K.cast(mask, K.floatx())

        # in some cases especially in the early stages of training the sum may be almost zero
        # and this results in NaN's. A workaround is to add a very small positive number ε to the sum.
        # a /= K.cast(K.sum(a, axis=1, keepdims=True), K.floatx())
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        return input_shape[0], input_shape[-1]


def lstm_with_attention(embedding_matrix,
                        **kwargs):
  STAMP = kwargs['STAMP']
  max_seq_length = kwargs['max_seq_length']
  EMBEDDING_DIM = kwargs['EMBEDDING_DIM']
  nb_words = kwargs['nb_words']

  inp = tf.keras.Input(shape=(max_seq_length,))
  embedded_seq = tf.keras.layers.Embedding(nb_words,
                EMBEDDING_DIM,
                weights=[embedding_matrix],
                trainable=False)(inp)
  x_1_bilstm = tf.keras.layers.Bidirectional(tf.compat.v1.keras.layers.CuDNNLSTM(128, return_sequences=True))(embedded_seq)
  x_1_bn = tf.keras.layers.BatchNormalization()(x_1_bilstm)
  x_2_bilstm = tf.keras.layers.Bidirectional(tf.compat.v1.keras.layers.CuDNNLSTM(64, return_sequences=True))(x_1_bn)
  attention = AttentionWithContext()(x_2_bilstm)
  x = tf.keras.layers.Dense(64, activation="relu")(attention)
  x = tf.keras.layers.Dense(1, activation="sigmoid")(x)
  model = tf.keras.Model(inputs=inp, outputs=x)
  optimizer = tf.keras.optimizers.Adam()
  model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
  model.summary()

  with open(STAMP + ".json", "w") as json_file: json_file.write(model.to_json())
  return model, attention

用注意力构建 lstm

embedding_matrix, nb_words = get_embedding('glove',word_index)
model, attention_layer = lstm_with_attention(embedding_matrix,STAMP=STAMP,max_seq_length=max_seq_length,nb_words=nb_words,EMBEDDING_DIM=EMBEDDING_DIM)

错误

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-54-4be6d63890f7> in <module>()
     20 # # BiGRU CuDNN
     21 embedding_matrix, nb_words = get_embedding('glove',word_index)
---> 22 model, attention_layer = lstm_with_attention(embedding_matrix,STAMP=STAMP,max_seq_length=max_seq_length,nb_words=nb_words,EMBEDDING_DIM=EMBEDDING_DIM)
     23 # gru_model = make_cudnn_gru_f(max_seq_length,embedding_matrix,loss_func=macro_soft_f1,eval_metric=macro_f1)
     24 # model = gru_model()

7 frames
<ipython-input-51-1ae8a90521d0> in lstm_with_attention(embedding_matrix, **kwargs)
    115   model.summary()
    116 
--> 117   with open(STAMP + ".json", "w") as json_file: json_file.write(model.to_json())
    118   return model, attention
    119 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in to_json(self, **kwargs)
   1296         A JSON string.
   1297     """
-> 1298     model_config = self._updated_config()
   1299     return json.dumps(
   1300         model_config, default=serialization.get_json_type, **kwargs)

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in _updated_config(self)
   1274     from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
   1275 
-> 1276     config = self.get_config()
   1277     model_config = {
   1278         'class_name': self.__class__.__name__,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in get_config(self)
    966     if not self._is_graph_network:
    967       raise NotImplementedError
--> 968     return copy.deepcopy(get_network_config(self))
    969 
    970   @classmethod

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   2117           filtered_inbound_nodes.append(node_data)
   2118 
-> 2119     layer_config = serialize_layer_fn(layer)
   2120     layer_config['name'] = layer.name
   2121     layer_config['inbound_nodes'] = filtered_inbound_nodes

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    273         return serialize_keras_class_and_config(
    274             name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
--> 275       raise e
    276     serialization_config = {}
    277     for key, item in config.items():

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    268     name = get_registered_name(instance.__class__)
    269     try:
--> 270       config = instance.get_config()
    271     except NotImplementedError as e:
    272       if _SKIP_FAILED_SERIALIZATION:

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in get_config(self)
    634       raise NotImplementedError('Layer %s has arguments in `__init__` and '
    635                                 'therefore must override `get_config`.' %
--> 636                                 self.__class__.__name__)
    637     return config
    638 

NotImplementedError: Layer AttentionWithContext has arguments in `__init__` and therefore must override `get_config`.

这是一个 TensorFlow 功能,用于通知您它不知道如何重建您的层,因为它不知道如何处理您的配置。这是来自 tensorflow 文档的引述:

get_config()

Returns the config of the layer.

A layer config is a Python dictionary (serializable) containing the configuration of a layer. The same layer can be reinstantiated later (without its trained weights) from this configuration.

The config of a layer does not include connectivity information, nor the layer class name. These are handled by Network (one layer of abstraction above).

要解决这个问题,你只需要在你的class中创建一个__init__对应的get_config方法来指示TensorFlow如何重新实例化你的层。

def get_config(self):
    config = super().get_config().copy()
    config.update({
            'W_regularizer': self.W_regularizer,
            'u_regularizer': self.u_regularizer,
            'b_regularizer': self.b_regularizer,
            'W_constraint': self.W_constraint,
            'u_constraint': self.u_constraint,
            'b_constraint': self.b_constraint,
            'bias': self.bias,
    })
    return config

那么你应该可以保存和加载它了。