如何在h5中保存一个损失函数为"balanced categorical entropy"的神经网络模型?
How to save a neural network model in h5 with loss function "balanced categorical entropy"?
我正在使用 VGG16 进行图像分割,损失函数为“平衡分类熵”,代码为
beta=0.5
def balanced_cross_entropy(beta):
def loss(y_true, y_pred):
weight_a = beta * tf.cast(y_true, tf.float32)
weight_b = (1 - beta) * tf.cast(1 - y_true, tf.float32)
o = (tf.math.log1p(tf.exp(-tf.abs(y_pred))) + tf.nn.relu(-y_pred)) * (weight_a + weight_b) + y_pred * weight_b
return tf.reduce_mean(o)
return loss
一切正常。现在我使用代码将这个模型保存在 h5 文件中。
vgg.save('vgg.h5')
但是当我使用 Keras 的 load_model 时
model = load_model('vgg.h5', custom_objects={'balanced_cross_entropy(beta)': balanced_cross_entropy(beta)})
我遇到错误。
Unknown loss function: loss. Please ensure this object is passed to the `custom_objects` argument.
任何人都可以帮忙吗,我怀疑问题可能是由于测试版造成的?
如果只想进行推理,可以通过指定
来避免这个问题
model = load_model('vgg.h5',compile=False)
否则需要按以下方式加载:
model = load_model("vgg.h5", custom_objects={'loss': balanced_cross_entropy(beta)})
;在你的代码中你写了 balanced_cross_entropy(beta)
而不是 loss
.
简短说明:
custom_object
中key的名字其实就是内层函数的名字(其实就是balanced_cross_entropy(beta)
返回的;外层函数的名字其实就是<key,value>
custom_object
字典中的对。
我正在使用 VGG16 进行图像分割,损失函数为“平衡分类熵”,代码为
beta=0.5
def balanced_cross_entropy(beta):
def loss(y_true, y_pred):
weight_a = beta * tf.cast(y_true, tf.float32)
weight_b = (1 - beta) * tf.cast(1 - y_true, tf.float32)
o = (tf.math.log1p(tf.exp(-tf.abs(y_pred))) + tf.nn.relu(-y_pred)) * (weight_a + weight_b) + y_pred * weight_b
return tf.reduce_mean(o)
return loss
一切正常。现在我使用代码将这个模型保存在 h5 文件中。
vgg.save('vgg.h5')
但是当我使用 Keras 的 load_model 时
model = load_model('vgg.h5', custom_objects={'balanced_cross_entropy(beta)': balanced_cross_entropy(beta)})
我遇到错误。
Unknown loss function: loss. Please ensure this object is passed to the `custom_objects` argument.
任何人都可以帮忙吗,我怀疑问题可能是由于测试版造成的?
如果只想进行推理,可以通过指定
来避免这个问题model = load_model('vgg.h5',compile=False)
否则需要按以下方式加载:
model = load_model("vgg.h5", custom_objects={'loss': balanced_cross_entropy(beta)})
;在你的代码中你写了 balanced_cross_entropy(beta)
而不是 loss
.
简短说明:
custom_object
中key的名字其实就是内层函数的名字(其实就是balanced_cross_entropy(beta)
返回的;外层函数的名字其实就是<key,value>
custom_object
字典中的对。