将 PyTorch 模型转换为 tf.keras 时无法加载 .pb

Unable to load .pb while converting PyTorch model to tf.keras

上下文

我正在将 tf.keras 用于个人项目,我需要检索预训练的 Alexnet 模型。 不幸的是,仅使用 tf.keras 无法直接访问此模型,因此我使用 PyTorch 下载了预训练模型,将其转换为 onnx 文件,然后使用以下代码将其导出为 .pb 文件:

torch_pretrained = torchvision.models.alexnet()
torch_pretrained.load_state_dict(torch.load("alexnet.pth"))

dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(torch_pretrained, dummy_input, "alexnet_pretrained.onnx")

onnx_pretrained = onnx.load("alexnet_pretrained.onnx")
onnx_pretrained = prepare(onnx_pretrained)
onnx_pretrained.export_graph('alexnet')

问题

我现在正尝试使用 keras 检索 .pb 文件,正如 所解释的那样,代码如下:

model = tf.keras.models.load_model("alexnet")
model.summary()

我得到一个错误:

AttributeError: '_UserObject' object has no attribute 'summary'

我在加载模型时也收到警告,但我认为这无关紧要:

WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), NOT tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.

加载的模型有一个非常模糊的类型,如您所见:

<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.._UserObject object at 0x0000023137981BB0>

在进行研究时,我发现 this 这意味着我不是唯一一个面临这个问题的人。

问题

最简单的方法是解决这个具体问题,但如果有人知道另一种将预训练的 Alexnet 模型加载到 tf.keras 的方法,这也可以解决我的实际问题。

规格

Windows 10
python 3.9.7
tensorflow 2.6.0
torch 1.10.2
torchvision 0.11.3
onnx 1.10.2
onnx-tf 1.9.0

解决方案

我听从了 Jakub 的建议:我安装了“pytorch2keras”(参见 this)。 我只是运行直接将pytorch模型转为keras模型的函数,果然有效

我只需要修改模块的代码,因为存在一些依赖性问题(他们正在使用 onnx.optimizer,现在称为 onnxoptimizer)所以我只是更改了导入行:

来自

from onnx import optimizer

import onnxoptimizer as optimizer