CublasLt cublasLtMatmulAlgoGetHeuristic returns CUBLAS_STATUS_INVALID_VALUE 对于行主矩阵

CublasLt cublasLtMatmulAlgoGetHeuristic returns CUBLAS_STATUS_INVALID_VALUE for rows major matrix

我刚刚完成重构我的程序以使用 GEMM 的 cublasLt 库,我陷入了 CUBLAS_STATUS_INVALID_VALUE在下面的函数中执行 cublasLtMatmulAlgoGetHeuristic 时。

CudaMatrix.cu:产品

/**
 * Performs the matrix-matrix multiplication C = A x B
 *
 * @see https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul
 *
 * @param A - The left matrix A
 * @param B - The right matrix B
 * @param C - The result matrix C
 * @param opA - Operation to perform on matrix A before multiplication (none, transpose or hermitian)
 * @param opB - Operation to perform on matrix B before multiplication (none, transpose or hermitian)
 * @param lightHandle - cublasLt handle
 */
template<typename precision>
void CudaMatrix<precision>::product(const CudaMatrix           &A,
                                    const CudaMatrix           &B,
                                          CudaMatrix           &C,
                                          cublasOperation_t    opA,
                                          cublasOperation_t    opB,
                                          cublasLtHandle_t     lightHandle
) {
    const precision                 zero               = 0,
                                    one                = 1;
    const int                       requestedAlgoCount = 1;
    cudaStream_t                    stream             = nullptr;
    cublasLtMatmulHeuristicResult_t heuristicResult;
    cublasLtMatmulPreference_t      preference;
    cublasLtMatmulDesc_t            computeDesc;
    int                             returnedAlgoCount;

    // Set matrix pre-operation such as transpose if any
    cublasLtCk(cublasLtMatmulDescCreate(&computeDesc, A.cublasLtDataType));
    cublasLtCk(cublasLtMatmulDescSetAttribute(computeDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA)));
    cublasLtCk(cublasLtMatmulDescSetAttribute(computeDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opB, sizeof(opB)));

    // Get the best algorithm to use
    cublasLtCk(cublasLtMatmulPreferenceCreate(&preference));
    cublasLtCk(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
               &CudaMatrix::matMulWorkspaceSize, sizeof(CudaMatrix::matMulWorkspaceSize)));
    cublasLtCk(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, A.matrixLayout, B.matrixLayout,
               C.matrixLayout, C.matrixLayout, preference, requestedAlgoCount, &heuristicResult, &returnedAlgoCount));

    std::cout << "returnedAlgoCount = " << returnedAlgoCount << std::endl;

    // Do the multiplication
    cublasLtCk(cublasLtMatmul(lightHandle, computeDesc, &one, A.data, A.matrixLayout, B.data, B.matrixLayout, &zero,
               C.data, C.matrixLayout, C.data, C.matrixLayout, &heuristicResult.algo,
               &CudaMatrix::matMulWorkspace, CudaMatrix::matMulWorkspaceSize, stream));

    // clean up
    cublasLtCk(cublasLtMatmulPreferenceDestroy(preference));
    cublasLtCk(cublasLtMatmulDescDestroy(computeDesc));
}

我将下面的最小可重现示例与我在程序中使用的相同源代码(带修剪)连接起来。

此错误可能与我在 NVIDIA forum 中发现的错误有关,但我不确定。

我 运行 在 Ubuntu 18.04 上使用 RTX 5000 GPU。

cublaslt_mat_mul.cu

#include <iostream>
#include <iomanip>
#include <limits>
#include <vector>
#include <cxxabi.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <cublasLt.h>

// ****************************************************************************************************************** //
//                                                    ErrorsCheck.cuh                                                 //
// ****************************************************************************************************************** //

static const char* cublasGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
        case CUBLAS_STATUS_SUCCESS:
            return "CUBLAS_STATUS_SUCCESS";

        case CUBLAS_STATUS_NOT_INITIALIZED:
            return "CUBLAS_STATUS_NOT_INITIALIZED";

        case CUBLAS_STATUS_ALLOC_FAILED:
            return "CUBLAS_STATUS_ALLOC_FAILED";

        case CUBLAS_STATUS_INVALID_VALUE:
            return "CUBLAS_STATUS_INVALID_VALUE";

        case CUBLAS_STATUS_ARCH_MISMATCH:
            return "CUBLAS_STATUS_ARCH_MISMATCH";

        case CUBLAS_STATUS_MAPPING_ERROR:
            return "CUBLAS_STATUS_MAPPING_ERROR";

        case CUBLAS_STATUS_EXECUTION_FAILED:
            return "CUBLAS_STATUS_EXECUTION_FAILED";

        case CUBLAS_STATUS_INTERNAL_ERROR:
            return "CUBLAS_STATUS_INTERNAL_ERROR";

        case CUBLAS_STATUS_NOT_SUPPORTED:
            return "CUBLAS_STATUS_NOT_SUPPORTED";

        case CUBLAS_STATUS_LICENSE_ERROR:
            return "CUBLAS_STATUS_LICENSE_ERROR";

        default:
            return "<unknown>";
    }
}

inline void cublasLtCheck(cublasStatus_t status, int iLine, const char *szFile) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        std::cerr << "CublasLt error " << cublasGetErrorEnum(status) << " at line " << iLine << " in file "
                  << szFile << std::endl;
    }
}

inline void cudaCheck(cudaError_t status, int iLine, const char *szFile) {
    if (status != cudaSuccess) {
        std::cerr << "CublasLt error " << cudaGetErrorString(status) << " at line " << iLine << " in file "
                  << szFile << std::endl;
    }
}

#define cublasLtCk(call) cublasLtCheck(call, __LINE__, __FILE__)
#define cudaCk(call) cudaCheck(call, __LINE__, __FILE__)

// ****************************************************************************************************************** //
//                                                    CudaMatrix.cuh                                                  //
// ****************************************************************************************************************** //

#define MB 1048576 // 2^19 byte

typedef unsigned int uint;

template <typename precision>
struct CudaMatrix {
    // Matrix multiplication GPU workspace that can be used to improve matrix multiplication computation time
    const static void   *matMulWorkspace;
    const static size_t matMulWorkspaceSize;

    CudaMatrix() : width(0), height(0), data(nullptr), cublasHandle(nullptr), cublasLtHandle(nullptr), matrixLayout(nullptr) { };
    CudaMatrix(uint width, uint height, cublasHandle_t cublasHandle = nullptr, cublasLtHandle_t cublasLtHandle = nullptr,
               cublasLtMatrixLayout_t matrixLayout = nullptr) : width(width), height(height), cublasHandle(cublasHandle),
               cublasLtHandle(cublasLtHandle), matrixLayout(matrixLayout)
    {
        cudaCk(cudaMalloc(&data, bytesSize()));

        if (typeid(precision).hash_code() == typeid(uint).hash_code()) {
            cublasLtDataType = CUDA_R_8U;
        } else if (typeid(precision).hash_code() == typeid(int).hash_code()) {
            cublasLtDataType = CUDA_R_8I;
        } else if (typeid(precision).hash_code() == typeid(float).hash_code()) {
            cublasLtDataType = CUDA_R_32F;
        } else if (typeid(precision).hash_code() == typeid(double).hash_code()) {
            cublasLtDataType = CUDA_R_64F;
        } else {
            throw std::runtime_error("The datatype " + std::string(typeid(precision).name()) + " is not handled in CudaMatrix");
        }

        cublasLtCk(cublasLtMatrixLayoutCreate(&matrixLayout, cublasLtDataType, height, width, width));

        if  (matMulWorkspace == nullptr) {
            cudaCk(cudaMalloc(&matMulWorkspace, matMulWorkspaceSize));
        }
    }

    __device__ __host__ uint size() const { return width * height; }

    static void product(const CudaMatrix &A, const CudaMatrix &B, CudaMatrix &C, cublasOperation_t opA, cublasOperation_t opB, cublasLtHandle_t lightHandle);

    void freeResources() { cudaCk(cudaFree(data)); cublasLtCk(cublasLtMatrixLayoutDestroy(matrixLayout)); }
    uint bytesSize() const { return size() * sizeof(precision); }
    void setValuesFromVector(const std::vector<precision> &vector);
    void setValuesFromVector(const std::vector<std::vector<precision>> &vectors);
    void display(const std::string &name = "", uint x = 0, uint y = 0, uint roiWidth = 0, uint roiHeight = 0) const;
    void product(const CudaMatrix &A) { product(*this, A, *this, CUBLAS_OP_N, CUBLAS_OP_N, cublasLtHandle); }

    precision              *data;
    uint                   width,
                           height;
    cublasHandle_t         cublasHandle;
    cublasLtHandle_t       cublasLtHandle;
    cublasLtMatrixLayout_t matrixLayout;
    cudaDataType_t         cublasLtDataType;
};

template <typename precision> const size_t CudaMatrix<precision>::matMulWorkspaceSize = 500 * MB;
template <typename precision> const void*  CudaMatrix<precision>::matMulWorkspace     = nullptr;

// ****************************************************************************************************************** //
//                                                     CudaMatrix.cu                                                  //
// ****************************************************************************************************************** //

/**
 * Display the matrix
 *
 * @tparam precision - The matrix precision
 *
 * @param name - The matrix name
 */
template <typename precision>
void CudaMatrix<precision>::display(const std::string &name, uint x, uint y, uint roiWidth, uint roiHeight) const
{
    precision *hostValues;

    roiWidth == 0 ? roiWidth = width : roiWidth = roiWidth;
    roiHeight == 0 ? roiHeight = height : roiHeight = roiHeight;

    cudaCk(cudaMallocHost(&hostValues, bytesSize()));
    cudaCk(cudaMemcpy(hostValues, data, bytesSize(), cudaMemcpyDeviceToHost));

    std::cout << std::setprecision(std::numeric_limits<precision>::digits10 + 1);

    std::cout << "Matrix " << name << " " << width << " x " << height << " pixels of "
              << abi::__cxa_demangle(typeid(precision).name(), nullptr, nullptr, nullptr)
              << "\n\n";

    for (int i = y; i < y + roiHeight; ++i) {
        std::cout << "{ ";

        for (int j = x; j < x + roiWidth - 1; ++j) {
            std::cout << *(hostValues + i * width + j) << ", ";
        }

        std::cout << *(hostValues + (i + 1) * width - 1) << " }\n";
    }

    std::cout << std::endl;

    cudaCk(cudaFreeHost(hostValues));
}

/**
 * Set the matrix values in device CUDA memory from a host standard 1D vector
 *
 * @tparam precision - The matrix precision
 *
 * @param vector - The values to set the device CUDA memory from
 */
template <typename precision>
void CudaMatrix<precision>::setValuesFromVector(const std::vector<precision> &vector)
{
    cudaCk(cudaMemcpy(data, vector.data(), vector.size() * sizeof(precision), cudaMemcpyHostToDevice));
}

/**
 * Set the matrix values in device CUDA memory from a host standard 2D vector
 *
 * @tparam precision - The matrix precision
 *
 * @param vectors - The values to set the device CUDA memory from
 */
template <typename precision>
void CudaMatrix<precision>::setValuesFromVector(const std::vector<std::vector<precision>> &vectors)
{
    std::vector<precision> buffer;

    buffer.reserve(vectors.size() * vectors[0].size());

    for (const auto &vector : vectors) {
        buffer.insert(buffer.end(), vector.begin(), vector.end());
    }

    setValuesFromVector(buffer);
}

/**
 * Performs the matrix-matrix multiplication C = A x B
 *
 * @see https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul
 *
 * @param A - The left matrix A
 * @param B - The right matrix B
 * @param C - The result matrix C
 * @param opA - Operation to perform on matrix A before multiplication (none, transpose or hermitian)
 * @param opB - Operation to perform on matrix B before multiplication (none, transpose or hermitian)
 * @param lightHandle - cublasLt handle
 */
template<typename precision>
void CudaMatrix<precision>::product(const CudaMatrix           &A,
                                    const CudaMatrix           &B,
                                          CudaMatrix           &C,
                                          cublasOperation_t    opA,
                                          cublasOperation_t    opB,
                                          cublasLtHandle_t     lightHandle
) {
    const precision                 zero               = 0,
                                    one                = 1;
    const int                       requestedAlgoCount = 1;
    cudaStream_t                    stream             = nullptr;
    cublasLtMatmulHeuristicResult_t heuristicResult;
    cublasLtMatmulPreference_t      preference;
    cublasLtMatmulDesc_t            computeDesc;
    int                             returnedAlgoCount;

    // Set matrix pre-operation such as transpose if any
    cublasLtCk(cublasLtMatmulDescCreate(&computeDesc, A.cublasLtDataType));
    cublasLtCk(cublasLtMatmulDescSetAttribute(computeDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA)));
    cublasLtCk(cublasLtMatmulDescSetAttribute(computeDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opB, sizeof(opB)));

    // Get the best algorithm to use
    cublasLtCk(cublasLtMatmulPreferenceCreate(&preference));
    cublasLtCk(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
               &CudaMatrix::matMulWorkspaceSize, sizeof(CudaMatrix::matMulWorkspaceSize)));
    cublasLtCk(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, A.matrixLayout, B.matrixLayout,
               C.matrixLayout, C.matrixLayout, preference, requestedAlgoCount, &heuristicResult, &returnedAlgoCount));

    std::cout << "returnedAlgoCount = " << returnedAlgoCount << std::endl;

    // Do the multiplication
    cublasLtCk(cublasLtMatmul(lightHandle, computeDesc, &one, A.data, A.matrixLayout, B.data, B.matrixLayout, &zero,
               C.data, C.matrixLayout, C.data, C.matrixLayout, &heuristicResult.algo,
               &CudaMatrix::matMulWorkspace, CudaMatrix::matMulWorkspaceSize, stream));

    // clean up
    cublasLtCk(cublasLtMatmulPreferenceDestroy(preference));
    cublasLtCk(cublasLtMatmulDescDestroy(computeDesc));
}

// Forward template declarations
template struct CudaMatrix<double>;
template struct CudaMatrix<float>;
template struct CudaMatrix<int>;
template struct CudaMatrix<uint>;

// ****************************************************************************************************************** //
//                                                        main.cu                                                     //
// ****************************************************************************************************************** //

int main(int argc, char const *argv[])
{
    cublasLtHandle_t   cublasLtHandle = nullptr;
    std::vector<float> r1Expect       = { 6, 6, 6, 15, 15, 15, 24, 24, 24 };
    std::vector<float> r2Expect       = { 1, 2, 3, 4, 5, 6, 7, 8, 9 };

    cublasLtCk(cublasLtCreate(&cublasLtHandle));

    // Declare matrices
    CudaMatrix<float> m1(3, 3);
    CudaMatrix<float> m2(3, 3);
    CudaMatrix<float> m3(3, 3);
    CudaMatrix<float> deviceResult(3, 3);

    // Set device memory values
    m1.setValuesFromVector({ {1, 1, 1}, {1, 1, 1}, {1, 1, 1} });
    m2.setValuesFromVector({ {1, 2, 3}, {4, 5, 6}, {7, 8, 9} });
    m3.setValuesFromVector({ {1, 0, 0}, {0, 1, 0}, {0, 0, 1} });

    // Test results (just showing it here)
    CudaMatrix<float>::product(m1, m2, deviceResult, CUBLAS_OP_N, CUBLAS_OP_N, cublasLtHandle);

    m1.display("m1");
    m2.display("m2");
    deviceResult.display("m1 X m2");

    CudaMatrix<float>::product(m2, m3, deviceResult, CUBLAS_OP_N, CUBLAS_OP_N, cublasLtHandle);

    m1.display("m2");
    m2.display("m3");
    deviceResult.display("m2 X m3");

    // Clean up
    cublasLtCk(cublasLtDestroy(cublasLtHandle));

    m1.freeResources();
    m2.freeResources();
    m3.freeResources();
    deviceResult.freeResources();

    return 0;
}

CMakeLists.txt

cmake_minimum_required(VERSION 3.10)
project(test-cuda)

# ------------------------------------------------ Compilation options ----------------------------------------------- #

# CUDA 10 does not support C++ 17
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
set(CMAKE_BUILD_TYPE Debug) # Release or Debug

# Include CUDA
find_package(CUDA REQUIRED)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -arch=sm_75 -std=c++14 --expt-relaxed-constexpr --expt-extended-lambda")

# ----------------------------------------------------- Constants ---------------------------------------------------- #

if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Release")
    MESSAGE(STATUS "Debug build")
    add_definitions(-DDEBUG_CUDA)
else ()
    MESSAGE(STATUS "Release build")
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
    set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3")
endif ()

# ------------------------------------------------- Source code files ------------------------------------------------ #

# All in one
file(GLOB matmul "cublaslt_mat_mul.cu")

# ---------------------------------------------------- Executables --------------------------------------------------- #

cuda_add_executable(matmulTest ${matmul})

# ---------------------------------------------------- Libraries ----------------------------------------------------- #

# Path to local libraries
file(GLOB CUDAlibs "/usr/lib/x86_64-linux-gnu/libcuda.so" "/usr/lib/x86_64-linux-gnu/libcublas.so" "/usr/lib/x86_64-linux-gnu/libcublasLt.so" "/usr/local/cuda/lib64/libcudart.so")
# Link libraries
target_link_libraries(matmulTest ${CUDAlibs})

输出

CublasLt error CUBLAS_STATUS_INVALID_VALUE at line 249 in file /home/rom1/Desktop/test_cuda/cublaslt_mat_mul.cu
returnedAlgoCount = -768202864
CublasLt error CUBLAS_STATUS_INVALID_VALUE at line 256 in file /home/rom1/Desktop/test_cuda/cublaslt_mat_mul.cu
Matrix m1 3 x 3 pixels of float

{ 1, 1, 1 }
{ 1, 1, 1 }
{ 1, 1, 1 }

Matrix m2 3 x 3 pixels of float

{ 1, 2, 3 }
{ 4, 5, 6 }
{ 7, 8, 9 }

Matrix m1 X m2 3 x 3 pixels of float

{ 0, 0, 0 }
{ 0, 0, 0 }
{ 0, 0, 0 }

CublasLt error CUBLAS_STATUS_INVALID_VALUE at line 249 in file /home/rom1/Desktop/test_cuda/cublaslt_mat_mul.cu
returnedAlgoCount = -870514560
CublasLt error CUBLAS_STATUS_INVALID_VALUE at line 256 in file /home/rom1/Desktop/test_cuda/cublaslt_mat_mul.cu
Matrix m2 3 x 3 pixels of float

{ 1, 1, 1 }
{ 1, 1, 1 }
{ 1, 1, 1 }

Matrix m3 3 x 3 pixels of float

{ 1, 0, 0 }
{ 0, 1, 0 }
{ 0, 0, 1 }

Matrix m2 X m3 3 x 3 pixels of float

{ 0, 0, 0 }
{ 0, 0, 0 }
{ 0, 0, 0 }

我犯了 2 个错误

matrixLayout没有设置好,我写了一个函数在每次乘法之前写它基于op应用于矩阵。

另外我把矩阵记忆放在专业而不是专业。

现在代码对于方形和非方形乘积以及行主存储器运行良好。

cublaslt_mat_mul.cu

#include <iostream>
#include <iomanip>
#include <limits>
#include <vector>
#include <cxxabi.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <cublasLt.h>

// ****************************************************************************************************************** //
//                                                    ErrorsCheck.cuh                                                 //
// ****************************************************************************************************************** //

static const char* cublasGetErrorEnum(cublasStatus_t error)
{
    switch (error)
    {
        case CUBLAS_STATUS_SUCCESS:
            return "CUBLAS_STATUS_SUCCESS";

        case CUBLAS_STATUS_NOT_INITIALIZED:
            return "CUBLAS_STATUS_NOT_INITIALIZED";

        case CUBLAS_STATUS_ALLOC_FAILED:
            return "CUBLAS_STATUS_ALLOC_FAILED";

        case CUBLAS_STATUS_INVALID_VALUE:
            return "CUBLAS_STATUS_INVALID_VALUE";

        case CUBLAS_STATUS_ARCH_MISMATCH:
            return "CUBLAS_STATUS_ARCH_MISMATCH";

        case CUBLAS_STATUS_MAPPING_ERROR:
            return "CUBLAS_STATUS_MAPPING_ERROR";

        case CUBLAS_STATUS_EXECUTION_FAILED:
            return "CUBLAS_STATUS_EXECUTION_FAILED";

        case CUBLAS_STATUS_INTERNAL_ERROR:
            return "CUBLAS_STATUS_INTERNAL_ERROR";

        case CUBLAS_STATUS_NOT_SUPPORTED:
            return "CUBLAS_STATUS_NOT_SUPPORTED";

        case CUBLAS_STATUS_LICENSE_ERROR:
            return "CUBLAS_STATUS_LICENSE_ERROR";

        default:
            return "<unknown>";
    }
}

inline void cublasLtCheck(cublasStatus_t status, int iLine, const char *szFile) {
    if (status != CUBLAS_STATUS_SUCCESS) {
        std::cerr << "CublasLt error " << cublasGetErrorEnum(status) << " at line " << iLine << " in file "
                  << szFile << std::endl;
    }
}

inline void cudaCheck(cudaError_t status, int iLine, const char *szFile) {
    if (status != cudaSuccess) {
        std::cerr << "CublasLt error " << cudaGetErrorString(status) << " at line " << iLine << " in file "
                  << szFile << std::endl;
    }
}

#define cublasLtCk(call) cublasLtCheck(call, __LINE__, __FILE__)
#define cudaCk(call) cudaCheck(call, __LINE__, __FILE__)

// ****************************************************************************************************************** //
//                                                    CudaMatrix.cuh                                                  //
// ****************************************************************************************************************** //

#define MB 1048576 // 2^19 byte

typedef unsigned int uint;

template <typename precision>
struct CudaMatrix {
    // Matrix multiplication GPU workspace that can be used to improve matrix multiplication computation time
    const static void   *matMulWorkspace;
    const static size_t matMulWorkspaceSize;

    CudaMatrix() : width(0), height(0), data(nullptr), cublasHandle(nullptr), cublasLtHandle(nullptr), matrixLayout(nullptr) { };
    CudaMatrix(uint width, uint height, cublasHandle_t cublasHandle = nullptr, cublasLtHandle_t cublasLtHandle = nullptr,
               cublasLtMatrixLayout_t matrixLayout = nullptr) : width(width), height(height), cublasHandle(cublasHandle),
               cublasLtHandle(cublasLtHandle), matrixLayout(matrixLayout)
    {
        cudaCk(cudaMalloc(&data, bytesSize()));

        if (typeid(precision).hash_code() == typeid(uint).hash_code()) {
            cublasLtDataType = CUDA_R_8U;
        } else if (typeid(precision).hash_code() == typeid(int).hash_code()) {
            cublasLtDataType = CUDA_R_8I;
        } else if (typeid(precision).hash_code() == typeid(float).hash_code()) {
            cublasLtDataType = CUDA_R_32F;
        } else if (typeid(precision).hash_code() == typeid(double).hash_code()) {
            cublasLtDataType = CUDA_R_64F;
        } else {
            throw std::runtime_error("The datatype " + std::string(typeid(precision).name()) + " is not handled in CudaMatrix");
        }

        if  (matMulWorkspace == nullptr) {
            cudaCk(cudaMalloc(&matMulWorkspace, matMulWorkspaceSize));
        }
    }

    __device__ __host__ uint size() const { return width * height; }

    static void product(CudaMatrix &A, CudaMatrix &B, CudaMatrix &C, cublasOperation_t opA, cublasOperation_t opB, cublasLtHandle_t lightHandle);

    void freeResources() { cudaCk(cudaFree(data)); cublasLtCk(cublasLtMatrixLayoutDestroy(matrixLayout)); }
    void setMatrixLayout(cublasOperation_t op, cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW);
    uint bytesSize() const { return size() * sizeof(precision); }
    void setValuesFromVector(const std::vector<precision> &vector);
    void setValuesFromVector(const std::vector<std::vector<precision>> &vectors);
    void display(const std::string &name = "", uint x = 0, uint y = 0, uint roiWidth = 0, uint roiHeight = 0) const;
    void product(CudaMatrix &A) { product(*this, A, *this, CUBLAS_OP_N, CUBLAS_OP_N, cublasLtHandle); }

    precision              *data;
    uint                   width,
                           height;
    cublasHandle_t         cublasHandle;
    cublasLtHandle_t       cublasLtHandle;
    cublasLtMatrixLayout_t matrixLayout;
    cudaDataType_t         cublasLtDataType;
};

template <typename precision> const size_t CudaMatrix<precision>::matMulWorkspaceSize = 500 * MB;
template <typename precision> const void*  CudaMatrix<precision>::matMulWorkspace     = nullptr;

// ****************************************************************************************************************** //
//                                                     CudaMatrix.cu                                                  //
// ****************************************************************************************************************** //

/**
 * Display the matrix
 *
 * @tparam precision - The matrix precision
 *
 * @param name - The matrix name
 */
template <typename precision>
void CudaMatrix<precision>::display(const std::string &name, uint x, uint y, uint roiWidth, uint roiHeight) const
{
    precision *hostValues;

    roiWidth == 0 ? roiWidth = width : roiWidth = roiWidth;
    roiHeight == 0 ? roiHeight = height : roiHeight = roiHeight;

    cudaCk(cudaMallocHost(&hostValues, bytesSize()));
    cudaCk(cudaMemcpy(hostValues, data, bytesSize(), cudaMemcpyDeviceToHost));

    std::cout << std::setprecision(std::numeric_limits<precision>::digits10 + 1);

    std::cout << "Matrix " << name << " " << width << " x " << height << " pixels of "
              << abi::__cxa_demangle(typeid(precision).name(), nullptr, nullptr, nullptr)
              << "\n\n";

    for (int i = y; i < y + roiHeight; ++i) {
        std::cout << "{ ";

        for (int j = x; j < x + roiWidth - 1; ++j) {
            std::cout << *(hostValues + i * width + j) << ", ";
        }

        std::cout << *(hostValues + (i + 1) * width - 1) << " }\n";
    }

    std::cout << std::endl;

    cudaCk(cudaFreeHost(hostValues));
}

/**
 * Set the matrix values in device CUDA memory from a host standard 1D vector
 *
 * @tparam precision - The matrix precision
 *
 * @param vector - The values to set the device CUDA memory from
 */
template <typename precision>
void CudaMatrix<precision>::setValuesFromVector(const std::vector<precision> &vector)
{
    cudaCk(cudaMemcpy(data, vector.data(), vector.size() * sizeof(precision), cudaMemcpyHostToDevice));
}

/**
 * Set the matrix values in device CUDA memory from a host standard 2D vector
 *
 * @tparam precision - The matrix precision
 *
 * @param vectors - The values to set the device CUDA memory from
 */
template <typename precision>
void CudaMatrix<precision>::setValuesFromVector(const std::vector<std::vector<precision>> &vectors)
{
    std::vector<precision> buffer;

    buffer.reserve(vectors.size() * vectors[0].size());

    for (const auto &vector : vectors) {
        buffer.insert(buffer.end(), vector.begin(), vector.end());
    }

    setValuesFromVector(buffer);
}

/**
 * Set the matrix layout before matrix multiplication with row major memory by default
 *
 * @tparam precision - The matrix precision
 *
 * @param op - Operation to perform on matrix before multiplication (none, transpose or hermitian)
 * @param matrixOrder - The matrix memory order (column or row DEFAULT row)
 */
template<typename precision>
void CudaMatrix<precision>:: setMatrixLayout(cublasOperation_t op, cublasLtOrder_t matrixOrder)
{
    const uint m = (op == CUBLAS_OP_N ? height : width),
               n = (op == CUBLAS_OP_N ? width : height);

    cublasLtCk(cublasLtMatrixLayoutCreate(&matrixLayout, cublasLtDataType, m, n, height));
    cublasLtCk(cublasLtMatrixLayoutSetAttribute(matrixLayout, CUBLASLT_MATRIX_LAYOUT_ORDER, &matrixOrder, sizeof(matrixOrder)));
}

/**
 * Performs the matrix-matrix multiplication C = A x B
 *
 * @see https://docs.nvidia.com/cuda/cublas/index.html#cublasLtMatmul
 *
 * @param A - The left matrix A
 * @param B - The right matrix B
 * @param C - The result matrix C
 * @param opA - Operation to perform on matrix A before multiplication (none, transpose or hermitian)
 * @param opB - Operation to perform on matrix B before multiplication (none, transpose or hermitian)
 * @param lightHandle - cublasLt handle
 */
template<typename precision>
void CudaMatrix<precision>::product(CudaMatrix           &A,
                                    CudaMatrix           &B,
                                    CudaMatrix           &C,
                                    cublasOperation_t    opA,
                                    cublasOperation_t    opB,
                                    cublasLtHandle_t     lightHandle
) {
    const precision                 zero               = 0,
                                    one                = 1;
    const int                       requestedAlgoCount = 1;
    cudaStream_t                    stream             = nullptr;
    cublasLtMatmulHeuristicResult_t heuristicResult;
    cublasLtMatmulPreference_t      preference;
    cublasLtMatmulDesc_t            computeDesc;
    int                             returnedAlgoCount;

    // Set matrix pre-operation such as transpose if any
    cublasLtCk(cublasLtMatmulDescCreate(&computeDesc, A.cublasLtDataType));
    cublasLtCk(cublasLtMatmulDescSetAttribute(computeDesc, CUBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA)));
    cublasLtCk(cublasLtMatmulDescSetAttribute(computeDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opB, sizeof(opB)));

    // Set matrices layout
    A.setMatrixLayout(opA);
    B.setMatrixLayout(opB);
    C.setMatrixLayout(CUBLAS_OP_N);

    // Get the best algorithm to use
    cublasLtCk(cublasLtMatmulPreferenceCreate(&preference));
    cublasLtCk(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
               &CudaMatrix::matMulWorkspaceSize, sizeof(CudaMatrix::matMulWorkspaceSize)));
    cublasLtCk(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, A.matrixLayout, B.matrixLayout,
               C.matrixLayout, C.matrixLayout, preference, requestedAlgoCount, &heuristicResult, &returnedAlgoCount));

    // Do the multiplication
    cublasLtCk(cublasLtMatmul(lightHandle, computeDesc, &one, A.data, A.matrixLayout, B.data, B.matrixLayout, &zero,
               C.data, C.matrixLayout, C.data, C.matrixLayout, &heuristicResult.algo,
               &CudaMatrix::matMulWorkspace, CudaMatrix::matMulWorkspaceSize, stream));

    // clean up
    cublasLtCk(cublasLtMatmulPreferenceDestroy(preference));
    cublasLtCk(cublasLtMatmulDescDestroy(computeDesc));
}

// Forward template declarations
template struct CudaMatrix<double>;
template struct CudaMatrix<float>;
template struct CudaMatrix<int>;
template struct CudaMatrix<uint>;

// ****************************************************************************************************************** //
//                                                        main.cu                                                     //
// ****************************************************************************************************************** //

int main(int argc, char const *argv[])
{
    cublasLtHandle_t   cublasLtHandle = nullptr;
    std::vector<float> r1Expect       = { 6, 6, 6, 15, 15, 15, 24, 24, 24 };
    std::vector<float> r2Expect       = { 1, 2, 3, 4, 5, 6, 7, 8, 9 };

    cublasLtCk(cublasLtCreate(&cublasLtHandle));

    // Declare matrices
    CudaMatrix<float> m1(3, 3);
    CudaMatrix<float> m2(3, 3);
    CudaMatrix<float> m3(3, 3);
    CudaMatrix<float> m4(3, 2);
    CudaMatrix<float> m5(2, 3);
    CudaMatrix<float> deviceResult_2_2(2, 2);
    CudaMatrix<float> deviceResult_3_3(3, 3);

    // Set device memory values
    m1.setValuesFromVector({ {1, 1, 1}, {1, 1, 1}, {1, 1, 1} });
    m2.setValuesFromVector({ {1, 2, 3}, {4, 5, 6}, {7, 8, 9} });
    m3.setValuesFromVector({ {1, 0, 0}, {0, 1, 0}, {0, 0, 1} });
    m4.setValuesFromVector({ {1, 2, 3}, {4, 5, 6} });
    m5.setValuesFromVector({ {1, 2}, { 3, 4 }, { 5 , 6 } });

    // Test results (just showing it here)
    CudaMatrix<float>::product(m1, m2, deviceResult_3_3, CUBLAS_OP_N, CUBLAS_OP_N, cublasLtHandle);

    deviceResult_3_3.display("m1 X m2");

    CudaMatrix<float>::product(m2, m3, deviceResult_3_3, CUBLAS_OP_N, CUBLAS_OP_N, cublasLtHandle);

    deviceResult_3_3.display("m2 X m3");

    CudaMatrix<float>::product(m4, m5, deviceResult_3_3, CUBLAS_OP_N, CUBLAS_OP_N, cublasLtHandle);

    deviceResult_3_3.display("m4 X m5");

    CudaMatrix<float>::product(m5, m4, deviceResult_2_2, CUBLAS_OP_N, CUBLAS_OP_N, cublasLtHandle);

    deviceResult_2_2.display("m5 X m4");

    // Clean up
    cublasLtCk(cublasLtDestroy(cublasLtHandle));

    m1.freeResources();
    m2.freeResources();
    m3.freeResources();
    deviceResult_2_2.freeResources();
    deviceResult_3_3.freeResources();

    return 0;
}