Tensorflow 给出与 Keras 不同的预测

Tensorflow gives different prediction than Keras

我有一个使用 1.10 Tensorflow 后端在 Keras 中训练的模型,我想使用 Tensorflow 2.4 进行推理。

我将 .h5 模型转换为 SavedModel 格式:

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model.signature_def_utils import predict_signature_def
from tensorflow.python.saved_model import tag_constants

def export_h5_to_pb(path_to_h5, export_path):
    if tf.executing_eagerly():
        tf.compat.v1.disable_eager_execution()
    loaded_model = load_model(path_to_h5)
    b = builder.SavedModelBuilder(export_path)
    signature = predict_signature_def(inputs={"inputs": loaded_model.input},
                                      outputs={"score": loaded_model.output})
    session = tf.compat.v1.Session()
    init_op = tf.compat.v1.global_variables_initializer()
    session.run(init_op)
    b.add_meta_graph_and_variables(
        sess=session, tags=[tag_constants.SERVING], signature_def_map={"serving_default": signature})
    b.save()
    
export_h5_to_pb('./trained_nework_VGG3_5comp.h5', './export/Servo/1')

Tensorflow 预测给我:

import tensorflow as tf
imported = tf.saved_model.load('./export/Servo/1', tags='serve')

f = imported.signatures["serving_default"]
f(inputs=tf.constant(test_payload))

> {'score': <tf.Tensor: shape=(1, 6), dtype=float32, numpy=array([[0.16693498, 0.16678475, 0.16666655, 0.16653116, 0.16678214,
     0.16630043]], dtype=float32)>}

虽然原始(正确的)Keras 预测给出:

from keras.models import load_model

model = load_model('./trained_nework_VGG3_5comp.h5')
model.predict(test_payload)

> array([[1.0000000e+00, 3.0078113e-09, 2.0143587e-10, 5.7580127e-09, 1.9100479e-09, 4.1776910e-10]], dtype=float32)

我做错了什么?

我有一个非常相似的问题,您已回复 here,我将分享对我有用的方法。如果你这样做是为了 Sagemaker/AWS(根据文件目录路径和使用“payload”这个词我假设你是?),那么问题是由 TensorFlow 版本差异引起的。

在我找到的所有博客中(例如 this one),他们在使用 TensorflowModel 加载模型时使用 framework_version 1.12。因此,我将 Sagemaker Jupyter 实例中的 TensorFlow 重新安装到版本 1.12,使用 1.12 重新训练我的模型,并将 framework_version 更改为 1.12,这对我有用。如果您没有使用适用于 AWS 的模型,那么这可能不适用,但如果您使用了,那么这是一个潜在的解决方案。祝你好运!