使用 python onnxruntime 进行预测时出错
Error Making prediction with python onnxruntime
我使用 sklearn
库创建了一个非常基本的决策树。这棵树是根据 4 个特征训练的:
feat1 INT
feat2 INT
feat3 FLOAT
feat4 FLOAT
并且 label/target 特征是一个布尔值(0 或 1)。
我将树转换为 ONNX
格式,现在我想使用 onnxruntime python
库进行预测。我在互联网上找到了示例代码来执行此操作。问题是我不明白这段代码、函数和参数的所有部分到底发生了什么。这导致我出错。我确实搜索了一些文档,但找不到这个。
在下面的代码中,我将树模型转换为 ONNX
格式。这是成功的,但我不明白部分代码。在 initial_type
变量中,根据我之前提到的 4 个特征列和 label/target 特征,我必须在这里输入什么?现在我输入了 FloatTensorType([None, 4]
,因为我有 4 个特征列,我不知道 None
是什么。
##Convert to ONNX format
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(treeModel, initial_types=initial_type)
with open("path", "wb") as f:
f.write(onx.SerializeToString())
在下面的代码中,我想使用 onnxruntime
库进行预测,但出现此错误:
RuntimeError: Either type_proto was null or it was not of sequence type
这是因为我看不懂下面最后一行代码。我输入这个 {input_name: [4, 8, 77.8, 143.45]
因为这是特征列的四个值。我在这里做错了什么?
sess = rt.InferenceSession("pathToONNXModel")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: [4, 8, 77.8, 143.45]})[0]
你试过了吗{input_name: numpy.array([4, 8, 77.8, 143.45], dtype=numpy.float32)}
? onnxruntime 需要 numpy 数组作为输入。
我使用 sklearn
库创建了一个非常基本的决策树。这棵树是根据 4 个特征训练的:
feat1 INT
feat2 INT
feat3 FLOAT
feat4 FLOAT
并且 label/target 特征是一个布尔值(0 或 1)。
我将树转换为 ONNX
格式,现在我想使用 onnxruntime python
库进行预测。我在互联网上找到了示例代码来执行此操作。问题是我不明白这段代码、函数和参数的所有部分到底发生了什么。这导致我出错。我确实搜索了一些文档,但找不到这个。
在下面的代码中,我将树模型转换为 ONNX
格式。这是成功的,但我不明白部分代码。在 initial_type
变量中,根据我之前提到的 4 个特征列和 label/target 特征,我必须在这里输入什么?现在我输入了 FloatTensorType([None, 4]
,因为我有 4 个特征列,我不知道 None
是什么。
##Convert to ONNX format
initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(treeModel, initial_types=initial_type)
with open("path", "wb") as f:
f.write(onx.SerializeToString())
在下面的代码中,我想使用 onnxruntime
库进行预测,但出现此错误:
RuntimeError: Either type_proto was null or it was not of sequence type
这是因为我看不懂下面最后一行代码。我输入这个 {input_name: [4, 8, 77.8, 143.45]
因为这是特征列的四个值。我在这里做错了什么?
sess = rt.InferenceSession("pathToONNXModel")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: [4, 8, 77.8, 143.45]})[0]
你试过了吗{input_name: numpy.array([4, 8, 77.8, 143.45], dtype=numpy.float32)}
? onnxruntime 需要 numpy 数组作为输入。