使用 LibTorch (PyTorch) 时在 C++ 中将 at::Tensor 转换为 double
Convert at::Tensor to double in C++ when using LibTorch (PyTorch)
在下面的代码中,我想比较 loss
(数据类型 at::Tensor
)和 lossThreshold
(数据类型 double
)。在进行比较之前,我想将 loss
转换为 double
。我该怎么做?
int main() {
auto const input1(torch::randn({28*28});
auto const input2(torch::randn({28*28});
double const lossThreshold{0.05};
auto const loss{torch::nn::functional::mse_loss(input1, input2)}; // this returns an at::Tensor datatype
return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}
感谢 GitHub CoPilot 推荐了这个解决方案。我想我现在应该辞职了。 :(
解决方案是使用 item<T>()
模板函数,如下所示:
int main() {
auto const input1(torch::randn({28*28}); // at::Tensor
auto const input2(torch::randn({28*28}); // at::Tensor
double const lossThreshold{0.05}; // double
auto const loss{torch::nn::functional::mse_loss(input1, input2).item<double>()}; // the item<double>() converts at::Tensor to double
return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}
在下面的代码中,我想比较 loss
(数据类型 at::Tensor
)和 lossThreshold
(数据类型 double
)。在进行比较之前,我想将 loss
转换为 double
。我该怎么做?
int main() {
auto const input1(torch::randn({28*28});
auto const input2(torch::randn({28*28});
double const lossThreshold{0.05};
auto const loss{torch::nn::functional::mse_loss(input1, input2)}; // this returns an at::Tensor datatype
return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}
感谢 GitHub CoPilot 推荐了这个解决方案。我想我现在应该辞职了。 :(
解决方案是使用 item<T>()
模板函数,如下所示:
int main() {
auto const input1(torch::randn({28*28}); // at::Tensor
auto const input2(torch::randn({28*28}); // at::Tensor
double const lossThreshold{0.05}; // double
auto const loss{torch::nn::functional::mse_loss(input1, input2).item<double>()}; // the item<double>() converts at::Tensor to double
return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}