如何从 libtorch 中的 Tensor 取回图像?

How to get back the image from Tensor in libtorch?

我试图从我之前创建的张量中取回图像并将其可视化,但是,生成的图像似乎是 distorted/garbage。 这就是我如何将 CV_8UC3 转换为相应的 at::Tensor:

at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kByte);

这就是我将其转换回图像的方式:

auto ToCvImage(at::Tensor tensor)
{
    int width = tensor.sizes()[0];
    int height = tensor.sizes()[1];
    try
    {
        cv::Mat output_mat(cv::Size{ height, width }, CV_8UC3, tensor.data_ptr<int>());
        return output_mat.clone();
    }
    catch (const c10::Error& e)
    {
        std::cout << "an error has occured : " << e.msg() << std::endl;
    }
    return cv::Mat(height, width, CV_8UC3);
}

原图是这样的:

这是我在转换后得到的:

现在,如果我在创建张量期间使用 at::kInt 而不是 kByte

at::Tensor tensor_image = torch::from_blob(img.data, { img.rows, img.cols, 3 }, at::kByte);

我再也看不到扭曲的图像了!但是,网络输出将关闭,这意味着输入出现问题!

这里有什么问题,我应该如何处理?

当使用 c10::kByte 创建张量时,我们需要使用 uchar 而不是 charuint 等。所以为了解决这个问题,我只需要使用 uchar 而不是 int:

auto ToCvImage(at::Tensor tensor)
{
    int width = tensor.sizes()[0];
    int height = tensor.sizes()[1];
    try
    {
        cv::Mat output_mat(cv::Size{ height, width }, CV_8UC3, tensor.data_ptr<uchar>());
        return output_mat.clone();
    }
    catch (const c10::Error& e)
    {
        std::cout << "an error has occured : " << e.msg() << std::endl;
    }
    return cv::Mat(height, width, CV_8UC3);
}

旁注: 如果您使用任何其他类型创建张量,请确保有效使用 Tensor::totype() 并事先转换为正确的类型。 那是在我将这个张量输入我的网络之前,例如我将它转换为 KFloat 然后继续!这是一个很明显的问题,很可能会被忽略,并且会花费您数小时的调试时间!