带参数的自定义激活
Custom activation with parameter
我正在尝试在 Keras 中创建一个可以接受参数 beta
的激活函数,如下所示:
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation
class Swish(Activation):
def __init__(self, activation, beta, **kwargs):
super(Swish, self).__init__(activation, **kwargs)
self.__name__ = 'swish'
self.beta = beta
def swish(x):
return (K.sigmoid(beta*x) * x)
get_custom_objects().update({'swish': Swish(swish, beta=1.)})
没有 beta
参数也能正常运行,但如何在激活定义中包含该参数?当我像激活 ELU 一样 model.to_json()
时,我也希望保存这个值。
更新:我根据@今天的回答写了下面的代码:
from keras.layers import Layer
from keras import backend as K
class Swish(Layer):
def __init__(self, beta, **kwargs):
super(Swish, self).__init__(**kwargs)
self.beta = K.cast_to_floatx(beta)
self.__name__ = 'swish'
def call(self, inputs):
return K.sigmoid(self.beta * inputs) * inputs
def get_config(self):
config = {'beta': float(self.beta)}
base_config = super(Swish, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def compute_output_shape(self, input_shape):
return input_shape
from keras.utils.generic_utils import get_custom_objects
get_custom_objects().update({'swish': Swish(beta=1.)})
gnn = keras.models.load_model("Model.h5")
arch = gnn.to_json()
with open(directory + 'architecture.json', 'w') as arch_file:
arch_file.write(arch)
但是,它当前不会在 .json 文件中保存 beta
值。怎样才能让它保值?
既然在序列化模型的时候要保存激活函数的参数,我觉得还是像advanced activations which have been defined in Keras一样把激活函数定义成一层比较好。你可以这样做:
from keras.layers import Layer
from keras import backend as K
class Swish(Layer):
def __init__(self, beta, **kwargs):
super(Swish, self).__init__(**kwargs)
self.beta = K.cast_to_floatx(beta)
def call(self, inputs):
return K.sigmoid(self.beta * inputs) * inputs
def get_config(self):
config = {'beta': float(self.beta)}
base_config = super(Swish, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def compute_output_shape(self, input_shape):
return input_shape
然后您可以像使用 Keras 层一样使用它:
# ...
model.add(Swish(beta=0.3))
由于get_config()
方法已经在其定义中实现,所以当使用to_json()
或save()
.
等方法时,参数beta
将被保存
我正在尝试在 Keras 中创建一个可以接受参数 beta
的激活函数,如下所示:
from keras import backend as K
from keras.utils.generic_utils import get_custom_objects
from keras.layers import Activation
class Swish(Activation):
def __init__(self, activation, beta, **kwargs):
super(Swish, self).__init__(activation, **kwargs)
self.__name__ = 'swish'
self.beta = beta
def swish(x):
return (K.sigmoid(beta*x) * x)
get_custom_objects().update({'swish': Swish(swish, beta=1.)})
没有 beta
参数也能正常运行,但如何在激活定义中包含该参数?当我像激活 ELU 一样 model.to_json()
时,我也希望保存这个值。
更新:我根据@今天的回答写了下面的代码:
from keras.layers import Layer
from keras import backend as K
class Swish(Layer):
def __init__(self, beta, **kwargs):
super(Swish, self).__init__(**kwargs)
self.beta = K.cast_to_floatx(beta)
self.__name__ = 'swish'
def call(self, inputs):
return K.sigmoid(self.beta * inputs) * inputs
def get_config(self):
config = {'beta': float(self.beta)}
base_config = super(Swish, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def compute_output_shape(self, input_shape):
return input_shape
from keras.utils.generic_utils import get_custom_objects
get_custom_objects().update({'swish': Swish(beta=1.)})
gnn = keras.models.load_model("Model.h5")
arch = gnn.to_json()
with open(directory + 'architecture.json', 'w') as arch_file:
arch_file.write(arch)
但是,它当前不会在 .json 文件中保存 beta
值。怎样才能让它保值?
既然在序列化模型的时候要保存激活函数的参数,我觉得还是像advanced activations which have been defined in Keras一样把激活函数定义成一层比较好。你可以这样做:
from keras.layers import Layer
from keras import backend as K
class Swish(Layer):
def __init__(self, beta, **kwargs):
super(Swish, self).__init__(**kwargs)
self.beta = K.cast_to_floatx(beta)
def call(self, inputs):
return K.sigmoid(self.beta * inputs) * inputs
def get_config(self):
config = {'beta': float(self.beta)}
base_config = super(Swish, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def compute_output_shape(self, input_shape):
return input_shape
然后您可以像使用 Keras 层一样使用它:
# ...
model.add(Swish(beta=0.3))
由于get_config()
方法已经在其定义中实现,所以当使用to_json()
或save()
.
beta
将被保存