如何在 C++ 中将火炬模型定义为函数的输入
how to define a torch model as an input of a function in c++
我正在加载一个在 python 中训练过的 C++ 模型。现在我想编写一个函数来测试带有随机输入的模型,但我不能将模型定义为函数的参数。我试过 struct 但它不起作用。
void test(vector<struct comp*>& model){
//pseudo input
vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1,3,224, 224}));
at::Tensor output = model[0]->forward(inputs).toTensor();
cout << output << endl;
}
int main(int argc, char *argv[]) {
if (argc == 2){
cout << argv[1] << endl;
//model = load_model(argv[1]);
torch::jit::script::Module module = torch::jit::load(argv[1]);
}
else {
cerr << "no path of model is given" << endl;
}
// test
vector<struct comp*> modul;
modul.push_back(module);
test(modul);
}
编辑:您需要将 module
变量放入范围内!
您的基本类型是 torch::jit::script::Module
所以为其定义一个名称:
using module_type = torch::jit::script::Module;
然后在您的代码中使用它,还使用 const
引用作为只读参数:
void test(const vector<module_type>& model){
//pseudo input
vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1,3,224, 224}));
at::Tensor output = model[0]->forward(inputs).toTensor();
cout << output << endl;
}
int main(int argc, char *argv[]) {
if (argc == 2){
cout << argv[1] << endl;
}
else {
cerr << "no path of model is given" << endl;
return -1;
}
// test
module_type module = torch::jit::load(argv[1]);;
vector<module_type> modul;
modul.push_back(module);
test(modul);
}
我正在加载一个在 python 中训练过的 C++ 模型。现在我想编写一个函数来测试带有随机输入的模型,但我不能将模型定义为函数的参数。我试过 struct 但它不起作用。
void test(vector<struct comp*>& model){
//pseudo input
vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1,3,224, 224}));
at::Tensor output = model[0]->forward(inputs).toTensor();
cout << output << endl;
}
int main(int argc, char *argv[]) {
if (argc == 2){
cout << argv[1] << endl;
//model = load_model(argv[1]);
torch::jit::script::Module module = torch::jit::load(argv[1]);
}
else {
cerr << "no path of model is given" << endl;
}
// test
vector<struct comp*> modul;
modul.push_back(module);
test(modul);
}
编辑:您需要将 module
变量放入范围内!
您的基本类型是 torch::jit::script::Module
所以为其定义一个名称:
using module_type = torch::jit::script::Module;
然后在您的代码中使用它,还使用 const
引用作为只读参数:
void test(const vector<module_type>& model){
//pseudo input
vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1,3,224, 224}));
at::Tensor output = model[0]->forward(inputs).toTensor();
cout << output << endl;
}
int main(int argc, char *argv[]) {
if (argc == 2){
cout << argv[1] << endl;
}
else {
cerr << "no path of model is given" << endl;
return -1;
}
// test
module_type module = torch::jit::load(argv[1]);;
vector<module_type> modul;
modul.push_back(module);
test(modul);
}