从 python 中的 ONNX 模型获取预测
Getting a prediction from an ONNX model in python
我找不到任何人向外行解释如何将 onnx 模型加载到 python 脚本中,然后在输入图像时使用该模型进行预测。我只能找到这些代码行:
sess = rt.InferenceSession("onnx_model.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred = sess.run([label_name], {input_name: X.astype(np.float32)})[0]
但我不知道这是什么意思。我到处看,每个人似乎都已经知道他们的意思,所以没有人解释。如果我可以 运行 这段代码,那将是一回事,但我不能。它给了我这个错误:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: Input3 Got: 2 Expected: 4 Please fix either the inputs or the model.
所以我需要真正知道这些东西的含义,这样我才能弄清楚如何修复错误。请懂行的人解释一下?
让我们首先检查您提供的代码,让一切都清楚。
sess = ort.InferenceSession("onnx_model.onnx")
此行将模型加载到会话对象中。这意味着模型中使用的层、函数和权重已准备好执行推理。
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
get_inputs
和 get_outputs
这两种方法各自检索有关模型的一些元信息,即模型期望的输入以及它可以提供的输出。在这些行中的元信息之外,实际上只使用了第一个输入和输出,而在这些之外,只有名称被获取并保存到变量中。
对于最后一行,让我们一部分一部分地处理。
pred = sess.run(...)[0]
这会对模型执行推理,之后我们将检查此方法的输入,但目前,输出是不同输出的列表。这些输出都是每个 numpy 数组。在这种情况下,仅使用此列表中的第一个输出,并将其保存到 pred
变量
([label_name], {input_name: X.astype(np.float32)})
这些是 sess.run
的输入。第一个是您希望会话计算的输出名称列表。第二个参数是一个字典,其中每个输入的名称映射到 numpy 数组。这些数组应与模型创建期间提供的数组具有相同的维度。同样,这些数组的类型也应与创建模型期间使用的类型相匹配。
您遇到的错误似乎表明提供的数组没有预期的尺寸。这些预期的维度数量似乎是 4。
为了清楚地了解输入数组的确切形状和数据类型应该是什么,可以使用可视化工具,例如 Netron
我找不到任何人向外行解释如何将 onnx 模型加载到 python 脚本中,然后在输入图像时使用该模型进行预测。我只能找到这些代码行:
sess = rt.InferenceSession("onnx_model.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred = sess.run([label_name], {input_name: X.astype(np.float32)})[0]
但我不知道这是什么意思。我到处看,每个人似乎都已经知道他们的意思,所以没有人解释。如果我可以 运行 这段代码,那将是一回事,但我不能。它给了我这个错误:
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid rank for input: Input3 Got: 2 Expected: 4 Please fix either the inputs or the model.
所以我需要真正知道这些东西的含义,这样我才能弄清楚如何修复错误。请懂行的人解释一下?
让我们首先检查您提供的代码,让一切都清楚。
sess = ort.InferenceSession("onnx_model.onnx")
此行将模型加载到会话对象中。这意味着模型中使用的层、函数和权重已准备好执行推理。
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
get_inputs
和 get_outputs
这两种方法各自检索有关模型的一些元信息,即模型期望的输入以及它可以提供的输出。在这些行中的元信息之外,实际上只使用了第一个输入和输出,而在这些之外,只有名称被获取并保存到变量中。
对于最后一行,让我们一部分一部分地处理。
pred = sess.run(...)[0]
这会对模型执行推理,之后我们将检查此方法的输入,但目前,输出是不同输出的列表。这些输出都是每个 numpy 数组。在这种情况下,仅使用此列表中的第一个输出,并将其保存到 pred
变量
([label_name], {input_name: X.astype(np.float32)})
这些是 sess.run
的输入。第一个是您希望会话计算的输出名称列表。第二个参数是一个字典,其中每个输入的名称映射到 numpy 数组。这些数组应与模型创建期间提供的数组具有相同的维度。同样,这些数组的类型也应与创建模型期间使用的类型相匹配。
您遇到的错误似乎表明提供的数组没有预期的尺寸。这些预期的维度数量似乎是 4。
为了清楚地了解输入数组的确切形状和数据类型应该是什么,可以使用可视化工具,例如 Netron