如何对具有多个输入和输出的 Tensorflow lite 模型进行推理?
How to do the inference for a Tensorflow lite model with multiple inputs and outputs?
我创建了一个简单的 tensorflow 分类模型,我将其转换并导出为 .tflite 文件。为了将模型集成到我的 android 应用程序中,我遵循了这个 tutorial,但它们仅涵盖 [=33= 的单个 input/output 模型类型]推理部分。
查看文档和其他一些资源后,我实施了以下解决方案:
// acc and gyro X, Y, Z are my features
float[] accX = new float[1];
float[] accY = new float[1];
float[] accZ = new float[1];
float[] gyroX = new float[1];
float[] gyroY = new float[1];
float[] gyroZ = new float[1];
Object[] inputs = new Object[]{accX, accY, accZ, gyroX, gyroY, gyroZ};
// And I have 4 classes
float[] output1 = new float[1];
float[] output2 = new float[1];
float[] output3 = new float[1];
float[] output4 = new float[1];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, output1);
outputs.put(1, output2);
outputs.put(2, output3);
outputs.put(3, output4);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
但是这段代码抛出异常:
java.lang.IllegalArgumentException: Invalid input Tensor index: 1
在这一步我不确定哪里出了问题。
这是我模型的架构:
model = tf.keras.Sequential([
tf.keras.layers.Dense(units=hp_units, input_shape=(6,), activation='relu'),
tf.keras.layers.Dense(240, activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
解法:
根据@Karim Nosseir 的回答,我使用了签名方法来访问我的模型的输入和输出。如果你有一个内置模型 python 那么你可以找到答案中的签名并使用它如下所示:
Python签名:
{'serving_default': {'inputs': ['dense_6_input'], 'outputs': ['dense_8']}}
Android java 使用:
float[] input = new float[6];
float[][] output = new float[1][4];
// Run decoding signature.
try (Interpreter interpreter = new Interpreter(loadModelFile())) {
Map<String, Object> inputs = new HashMap<>();
inputs.put("dense_6_input", input);
Map<String, Object> outputs = new HashMap<>();
outputs.put("dense_8", output);
interpreter.runSignature(inputs, outputs, "serving_default");
} catch (IOException e) {
e.printStackTrace();
}
最简单的是使用签名 API 并使用 inputs/outputs
的签名名称
如果您使用 v2 TFLite Converter,您应该找到定义的签名。
打印定义的签名的示例如下
model = tf.keras.Sequential([
tf.keras.layers.Dense(4, input_shape=(6,), activation='relu'),
tf.keras.layers.Dense(240, activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_signature_list())
请参阅指南 here,了解如何 运行 Java 和其他语言。
我创建了一个简单的 tensorflow 分类模型,我将其转换并导出为 .tflite 文件。为了将模型集成到我的 android 应用程序中,我遵循了这个 tutorial,但它们仅涵盖 [=33= 的单个 input/output 模型类型]推理部分。 查看文档和其他一些资源后,我实施了以下解决方案:
// acc and gyro X, Y, Z are my features
float[] accX = new float[1];
float[] accY = new float[1];
float[] accZ = new float[1];
float[] gyroX = new float[1];
float[] gyroY = new float[1];
float[] gyroZ = new float[1];
Object[] inputs = new Object[]{accX, accY, accZ, gyroX, gyroY, gyroZ};
// And I have 4 classes
float[] output1 = new float[1];
float[] output2 = new float[1];
float[] output3 = new float[1];
float[] output4 = new float[1];
Map<Integer, Object> outputs = new HashMap<>();
outputs.put(0, output1);
outputs.put(1, output2);
outputs.put(2, output3);
outputs.put(3, output4);
interpreter.runForMultipleInputsOutputs(inputs, outputs);
但是这段代码抛出异常:
java.lang.IllegalArgumentException: Invalid input Tensor index: 1
在这一步我不确定哪里出了问题。
这是我模型的架构:
model = tf.keras.Sequential([
tf.keras.layers.Dense(units=hp_units, input_shape=(6,), activation='relu'),
tf.keras.layers.Dense(240, activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
解法:
根据@Karim Nosseir 的回答,我使用了签名方法来访问我的模型的输入和输出。如果你有一个内置模型 python 那么你可以找到答案中的签名并使用它如下所示:
Python签名:
{'serving_default': {'inputs': ['dense_6_input'], 'outputs': ['dense_8']}}
Android java 使用:
float[] input = new float[6];
float[][] output = new float[1][4];
// Run decoding signature.
try (Interpreter interpreter = new Interpreter(loadModelFile())) {
Map<String, Object> inputs = new HashMap<>();
inputs.put("dense_6_input", input);
Map<String, Object> outputs = new HashMap<>();
outputs.put("dense_8", output);
interpreter.runSignature(inputs, outputs, "serving_default");
} catch (IOException e) {
e.printStackTrace();
}
最简单的是使用签名 API 并使用 inputs/outputs
的签名名称如果您使用 v2 TFLite Converter,您应该找到定义的签名。
打印定义的签名的示例如下
model = tf.keras.Sequential([
tf.keras.layers.Dense(4, input_shape=(6,), activation='relu'),
tf.keras.layers.Dense(240, activation='relu'),
tf.keras.layers.Dense(4, activation='softmax')
])
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
print(interpreter.get_signature_list())
请参阅指南 here,了解如何 运行 Java 和其他语言。