如何对具有多个输入和输出的 Tensorflow lite 模型进行推理?

How to do the inference for a Tensorflow lite model with multiple inputs and outputs?

我创建了一个简单的 tensorflow 分类模型,我将其转换并导出为 .tflite 文件。为了将模型集成到我的 android 应用程序中,我遵循了这个 tutorial,但它们仅涵盖 [=33= 的单个 input/output 模型类型]推理部分。 查看文档和其他一些资源后,我实施了以下解决方案:

        // acc and gyro X, Y, Z are my features
        float[] accX = new float[1];
        float[] accY = new float[1];
        float[] accZ = new float[1];

        float[] gyroX = new float[1];
        float[] gyroY = new float[1];
        float[] gyroZ = new float[1];


        Object[] inputs = new Object[]{accX, accY, accZ, gyroX, gyroY, gyroZ};
        
        // And I have 4 classes
        float[] output1 = new float[1];
        float[] output2 = new float[1];
        float[] output3 = new float[1];
        float[] output4 = new float[1];

        Map<Integer, Object> outputs = new HashMap<>();
        outputs.put(0, output1);
        outputs.put(1, output2);
        outputs.put(2, output3);
        outputs.put(3, output4);

        interpreter.runForMultipleInputsOutputs(inputs, outputs);

但是这段代码抛出异常:

java.lang.IllegalArgumentException: Invalid input Tensor index: 1

在这一步我不确定哪里出了问题。

这是我模型的架构:

 model = tf.keras.Sequential([
        tf.keras.layers.Dense(units=hp_units, input_shape=(6,), activation='relu'),
        tf.keras.layers.Dense(240, activation='relu'),
        tf.keras.layers.Dense(4, activation='softmax')
    ])

解法:

根据@Karim Nosseir 的回答,我使用了签名方法来访问我的模型的输入和输出。如果你有一个内置模型 python 那么你可以找到答案中的签名并使用它如下所示:

Python签名:

{'serving_default': {'inputs': ['dense_6_input'], 'outputs': ['dense_8']}}

Android java 使用:

        float[] input = new float[6];
        float[][] output = new float[1][4];
        
        // Run decoding signature.
        try (Interpreter interpreter = new Interpreter(loadModelFile())) {
            Map<String, Object> inputs = new HashMap<>();
            inputs.put("dense_6_input", input);

            Map<String, Object> outputs = new HashMap<>();
            outputs.put("dense_8", output);

            interpreter.runSignature(inputs, outputs, "serving_default");
        } catch (IOException e) {
            e.printStackTrace();
        }

最简单的是使用签名 API 并使用 inputs/outputs

的签名名称

如果您使用 v2 TFLite Converter,您应该找到定义的签名。

打印定义的签名的示例如下

model = tf.keras.Sequential([
        tf.keras.layers.Dense(4, input_shape=(6,), activation='relu'),
        tf.keras.layers.Dense(240, activation='relu'),
        tf.keras.layers.Dense(4, activation='softmax')
    ])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_signature_list())

请参阅指南 here,了解如何 运行 Java 和其他语言。