Operation ParseExample not supported while converting SavedModel to TFLite

I'm training and saving a model with a TensorFlow Estimator and then converting it to .tflite. I save the model as follows:

import tensorflow as tf

feat_cols = [tf.feature_column.numeric_column('feature1'),
             tf.feature_column.numeric_column('feature2'),
             tf.feature_column.numeric_column('feature3'),
             tf.feature_column.numeric_column('feature4')]

def serving_input_receiver_fn():
    """An input receiver that expects a serialized tf.Example."""
    feature_spec = tf.feature_column.make_parse_example_spec(feat_cols)
    default_batch_size = 1
    serialized_tf_example = tf.placeholder(dtype=tf.string, shape=[default_batch_size], name='tf_example')
    receiver_tensors = {'examples': serialized_tf_example}
    # tf.parse_example is what ends up as the ParseExample op in the exported graph
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)


dnn_regressor.export_saved_model(export_dir_base='model',
                                 serving_input_receiver_fn=serving_input_receiver_fn)
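
For context, a client of this signature feeds a serialized tf.Example string through the 'examples' tensor. A minimal sketch of building one (the feature values are made up):

example = tf.train.Example(features=tf.train.Features(feature={
    name: tf.train.Feature(float_list=tf.train.FloatList(value=[0.5]))
    for name in ['feature1', 'feature2', 'feature3', 'feature4']
}))
serialized = example.SerializeToString()  # goes into the 'examples' placeholder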

When I try to convert the resulting .pb file with:

tflite_convert --output_file=/tmp/foo.tflite --saved_model_dir=/tmp/saved_model

I get an exception saying that the ParseExample operation is not supported by TensorFlow Lite:

Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If those are native TensorFlow operators, you might be able to use the extended runtime by passing --enable_select_tf_ops, or by setting target_ops=TFLITE_BUILTINS,SELECT_TF_OPS when calling tf.lite.TFLiteConverter(). Otherwise, if you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.lite.TFLiteConverter(). Here is a list of builtin operators you are using: CONCATENATION, FULLY_CONNECTED, RESHAPE. Here is a list of operators for which you will need custom implementations: ParseExample.
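
If you want to keep the serialized-Example interface, the error message points at the select-TF-ops fallback. A minimal sketch of that route with the TF 1.x Python API (I haven't verified this; it requires a TFLite runtime built with select-TF-ops/Flex support, and the attribute name varies across versions):

converter = tf.lite.TFLiteConverter.from_saved_model('/tmp/saved_model')
# Let unsupported native ops such as ParseExample fall back to the TensorFlow runtime
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                        tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open('/tmp/foo.tflite', 'wb').write(tflite_model)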

If I instead try to export the model without serialization, then when I run predictions against the resulting .pb file, the function expects an empty set() instead of the dict of inputs I'm passing:

ValueError: Got unexpected keys in input_dict: {'feature1', 'feature2', 'feature3', 'feature4'} expected: set()

What am I doing wrong? Here is the code that tries to save the model without any serialization:

features = {
    'feature1': tf.placeholder(dtype=tf.float32, shape=[1], name='feature1'),
    'feature2': tf.placeholder(dtype=tf.float32, shape=[1], name='feature2'),
    'feature3': tf.placeholder(dtype=tf.float32, shape=[1], name='feature3'),
    'feature4': tf.placeholder(dtype=tf.float32, shape=[1], name='feature4')
}

def serving_input_receiver_fn():
    return tf.estimator.export.ServingInputReceiver(features, features)


dnn_regressor.export_savedmodel(export_dir_base='model', serving_input_receiver_fn=serving_input_receiver_fn, as_text=True)

SOLVED

Using build_raw_serving_input_receiver_fn I managed to export the SavedModel without any serialization:

# 'features' is the dict of raw float placeholders defined above
serve_input_fun = tf.estimator.export.build_raw_serving_input_receiver_fn(
    features,
    default_batch_size=None
)

dnn_regressor.export_savedmodel(
    export_dir_base="model",
    serving_input_receiver_fn=serve_input_fun,
    as_text=True
)

Note: when making predictions, the Predictor doesn't know about the default signature_def, so I needed to specify it:

from tensorflow.contrib import predictor

predict_fn = predictor.from_saved_model("model/155482...", signature_def_key="predict")
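
The returned predict_fn can then be called with the raw feature dict directly; for example (the values are made up, and each shape matches the [1] placeholders above):

prediction = predict_fn({
    'feature1': [0.1],
    'feature2': [0.2],
    'feature3': [0.3],
    'feature4': [0.4],
})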

For the .pb-to-.tflite conversion I used the Python API, since I also needed to specify the signature_def there:

converter = tf.contrib.lite.TFLiteConverter.from_saved_model('model/155482....', signature_key='predict')
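
From there, the usual convert-and-write steps finish the job:

tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)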