无法加载经过 keras 训练的模型
Not able to load keras trained model
我正在使用以下代码来训练 HAN 网络。
Code Link
我已经成功地训练了模型,但是当我尝试使用 keras load_model 加载模型时,出现了以下错误-
未知层:AttentionWithContext
根据您分享的link,您的模型有一个明确定义的层 AttentionWithContext() 添加到模型中。当您尝试使用 keras 的 load_model 加载模型时,该方法显示错误,因为该层不是内置在 keras 中,要解决此问题,您可能必须在加载模型之前在代码中再次定义该层使用 load_model。在尝试加载模型之前,请尝试在您提供的 link (https://www.kaggle.com/hsankesara/news-classification-using-han/notebook) 中编写 class AttentionWithContext(layer)。
在AttentionWithContext.py文件中添加以下函数:
def create_custom_objects():
instance_holder = {"instance": None}
class ClassWrapper(AttentionWithContext):
def __init__(self, *args, **kwargs):
instance_holder["instance"] = self
super(ClassWrapper, self).__init__(*args, **kwargs)
def loss(*args):
method = getattr(instance_holder["instance"], "loss_function")
return method(*args)
def accuracy(*args):
method = getattr(instance_holder["instance"], "accuracy")
return method(*args)
return {"ClassWrapper": ClassWrapper ,"AttentionWithContext": ClassWrapper, "loss": loss,
"accuracy":accuracy}
加载模型时:
from AttentionWithContext import create_custom_objects
model = keras.models.load_model(model_path, custom_objects=create_custom_objects())
model.evaluate(X_test, y_test) # or model.predict
我正在使用以下代码来训练 HAN 网络。 Code Link
我已经成功地训练了模型,但是当我尝试使用 keras load_model 加载模型时,出现了以下错误- 未知层:AttentionWithContext
根据您分享的link,您的模型有一个明确定义的层 AttentionWithContext() 添加到模型中。当您尝试使用 keras 的 load_model 加载模型时,该方法显示错误,因为该层不是内置在 keras 中,要解决此问题,您可能必须在加载模型之前在代码中再次定义该层使用 load_model。在尝试加载模型之前,请尝试在您提供的 link (https://www.kaggle.com/hsankesara/news-classification-using-han/notebook) 中编写 class AttentionWithContext(layer)。
在AttentionWithContext.py文件中添加以下函数:
def create_custom_objects():
instance_holder = {"instance": None}
class ClassWrapper(AttentionWithContext):
def __init__(self, *args, **kwargs):
instance_holder["instance"] = self
super(ClassWrapper, self).__init__(*args, **kwargs)
def loss(*args):
method = getattr(instance_holder["instance"], "loss_function")
return method(*args)
def accuracy(*args):
method = getattr(instance_holder["instance"], "accuracy")
return method(*args)
return {"ClassWrapper": ClassWrapper ,"AttentionWithContext": ClassWrapper, "loss": loss,
"accuracy":accuracy}
加载模型时:
from AttentionWithContext import create_custom_objects
model = keras.models.load_model(model_path, custom_objects=create_custom_objects())
model.evaluate(X_test, y_test) # or model.predict