Python OpenCV RTrees 无法正确加载
Python OpenCV RTrees does not load properly
我有一个使用 OpenCV 版本 3.3.1 训练的 RTrees 模型,我已经保存了它,需要稍后加载以进行预测。
我正在尝试使用 OpenCV 的 RTrees 创建随机森林分类器。但是,我无法成功加载使用保存功能保存的模型。
我使用的代码示例如下:
import numpy as np
import cv2
def train(samples, class_labels, save_file=None):
"""
samples : np.ndarray of type np.float32
class_labels : np.ndarray of type int
save_file : str of the absolute path to the save file
"""
model = cv2.ml.RTrees_create()
# Set paremeters
model.setMaxDepth(20)
model.setActiveVarCount(0)
term_type, n_trees, epsilon = cv2.TERM_CRITERIA_MAX_ITER, 128, 1
model.setTermCriteria((term_type, n_trees, epsilon))
train_data = cv2.ml.TrainData_create(samples=samples,
layout=cv2.ml.ROW_SAMPLE,
responses=class_labels)
model.train(trainData=train_data)
if save_file:
model.save(save_file)
def test(samples, load_file):
"""
samples : np.ndarray of type np.float32
load_file : str of the absolute path to the load file
"""
model = cv2.ml.RTrees_create()
model.load(load_file)
_ret, responses = model.predict(samples)
return responses.ravel()
if __name__ == '__main__':
model_file = '/home/user/bin/tmp/m.xml'
x = np.random.randint(0, 5, (1000, 15))
y = np.random.randint(0, 4, 1000)
train(x, y, model_file)
x_test = np.random.randint(0, 5, (100, 15))
test(x_test, model_file)
错误如下:
Traceback (most recent call last):
File "tmp.py", line 45, in <module>
test(x_test, model_file)
File "tmp.py", line 34, in test
_ret, responses = model.predict(samples)
cv2.error: OpenCV(3.4.2) /io/opencv/modules/ml/src/tree.cpp:1498: error: (-215:Assertion failed) !roots.empty() in function 'predict'
问题是加载函数是 returns 加载模型的静态函数。
OpenCV的文档和例子没有具体说明这一点。
...
model = cv2.ml.RTrees_create()
model = model.load(file_name)
# Or:
model = cv2.ml.RTrees_load(file_name)
请参阅 OpenCV 网站上的 this。
我有一个使用 OpenCV 版本 3.3.1 训练的 RTrees 模型,我已经保存了它,需要稍后加载以进行预测。
我正在尝试使用 OpenCV 的 RTrees 创建随机森林分类器。但是,我无法成功加载使用保存功能保存的模型。 我使用的代码示例如下:
import numpy as np
import cv2
def train(samples, class_labels, save_file=None):
"""
samples : np.ndarray of type np.float32
class_labels : np.ndarray of type int
save_file : str of the absolute path to the save file
"""
model = cv2.ml.RTrees_create()
# Set paremeters
model.setMaxDepth(20)
model.setActiveVarCount(0)
term_type, n_trees, epsilon = cv2.TERM_CRITERIA_MAX_ITER, 128, 1
model.setTermCriteria((term_type, n_trees, epsilon))
train_data = cv2.ml.TrainData_create(samples=samples,
layout=cv2.ml.ROW_SAMPLE,
responses=class_labels)
model.train(trainData=train_data)
if save_file:
model.save(save_file)
def test(samples, load_file):
"""
samples : np.ndarray of type np.float32
load_file : str of the absolute path to the load file
"""
model = cv2.ml.RTrees_create()
model.load(load_file)
_ret, responses = model.predict(samples)
return responses.ravel()
if __name__ == '__main__':
model_file = '/home/user/bin/tmp/m.xml'
x = np.random.randint(0, 5, (1000, 15))
y = np.random.randint(0, 4, 1000)
train(x, y, model_file)
x_test = np.random.randint(0, 5, (100, 15))
test(x_test, model_file)
错误如下:
Traceback (most recent call last):
File "tmp.py", line 45, in <module>
test(x_test, model_file)
File "tmp.py", line 34, in test
_ret, responses = model.predict(samples)
cv2.error: OpenCV(3.4.2) /io/opencv/modules/ml/src/tree.cpp:1498: error: (-215:Assertion failed) !roots.empty() in function 'predict'
问题是加载函数是 returns 加载模型的静态函数。 OpenCV的文档和例子没有具体说明这一点。
...
model = cv2.ml.RTrees_create()
model = model.load(file_name)
# Or:
model = cv2.ml.RTrees_load(file_name)
请参阅 OpenCV 网站上的 this。