NotImplementedError: Layer attention has arguments in `__init__` and therefore must override `get_config`
NotImplementedError: Layer attention has arguments in `__init__` and therefore must override `get_config`
我已经按照 link 中的建议实施了自定义注意力层:
class attention(Layer):
def __init__(self, return_sequences=True):
self.return_sequences = return_sequences
super(attention,self).__init__()
def build(self, input_shape):
self.W=self.add_weight(name="att_weight", shape=(input_shape[-1],1),
initializer="normal")
self.b=self.add_weight(name="att_bias", shape=(input_shape[1],1),
initializer="zeros")
super(attention,self).build(input_shape)
def call(self, x):
e = K.tanh(K.dot(x,self.W)+self.b)
a = K.softmax(e, axis=1)
output = x*a
if self.return_sequences:
return output
return K.sum(output, axis=1)
代码 运行 但是当模型需要保存时我遇到了这个错误。
NotImplementedError:层注意在 __init__
中有参数,因此必须覆盖 get_config
。
一些评论建议覆盖 get_config。
“此错误让您知道 tensorflow 无法保存您的模型,因为它无法加载它。
具体来说,它将无法重新实例化您的自定义层 类。
要解决这个问题,只需根据您添加的新参数覆盖他们的 get_config 方法即可。"
Link 评论:
我的问题是,基于上面的自定义注意力层,我该如何编写 get_config 来解决这个错误?
您需要这样的配置方法:
def get_config(self):
config = super().get_config().copy()
config.update({
'return_sequences': self.return_sequences
})
return config
所有需要的信息都在您链接的另一个 中。
我已经按照 link 中的建议实施了自定义注意力层:
class attention(Layer):
def __init__(self, return_sequences=True):
self.return_sequences = return_sequences
super(attention,self).__init__()
def build(self, input_shape):
self.W=self.add_weight(name="att_weight", shape=(input_shape[-1],1),
initializer="normal")
self.b=self.add_weight(name="att_bias", shape=(input_shape[1],1),
initializer="zeros")
super(attention,self).build(input_shape)
def call(self, x):
e = K.tanh(K.dot(x,self.W)+self.b)
a = K.softmax(e, axis=1)
output = x*a
if self.return_sequences:
return output
return K.sum(output, axis=1)
代码 运行 但是当模型需要保存时我遇到了这个错误。
NotImplementedError:层注意在 __init__
中有参数,因此必须覆盖 get_config
。
一些评论建议覆盖 get_config。
“此错误让您知道 tensorflow 无法保存您的模型,因为它无法加载它。 具体来说,它将无法重新实例化您的自定义层 类。
要解决这个问题,只需根据您添加的新参数覆盖他们的 get_config 方法即可。"
Link 评论:
我的问题是,基于上面的自定义注意力层,我该如何编写 get_config 来解决这个错误?
您需要这样的配置方法:
def get_config(self):
config = super().get_config().copy()
config.update({
'return_sequences': self.return_sequences
})
return config
所有需要的信息都在您链接的另一个