将 CuPy CUDA 设备指针传递给 pybind11

Passing CuPy CUDA device pointer to pybind11

我正在尝试使用 CuPy 在 GPU 内存中实例化一个数组,然后使用 pybind11 将指向该数组的指针传递给 C++。

下面显示了我 运行 遇到的问题的最小示例。

Python

import demolib #compiled pybind11 library
import cupy as cp

x = cp.ones(100000)
y = cp.ones(100000)

demolib.pyadd(len(x),x.data.ptr,y.data.ptr)

C++/CUDA

#include <iostream>
#include <math.h>
#include <cuda_runtime.h>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

// Error Checking Function
#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
   if (code != cudaSuccess)
   {
      fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
      if (abort) exit(code);
   }
}

// Simple CUDA kernel
__global__
void cuadd(int n, float *x, float *y)
{
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;
  for (int i = index; i < n; i += stride)
    y[i] = x[i] + y[i];
}

// Simple wrapper function to be exposed to Python
int pyadd(int N, float *x, float *y)
{

  // Run kernel on 1M elements on the GPU
  int blockSize = 256;
  int numBlocks = (N + blockSize - 1) / blockSize;
  cuadd<<<numBlocks, blockSize>>>(N,x,y);

  // Wait for GPU to finish before accessing on host
  gpuErrchk( cudaPeekAtLastError() );
  gpuErrchk( cudaDeviceSynchronize() );

  return 0;
}

PYBIND11_MODULE(demolib, m) {
        m.doc() = "pybind11 example plugin"; // optional module docstring
        m.def("pyadd", &pyadd, "A function which adds two numbers");
}

代码抛出以下错误:

GPUassert: an illegal memory access was encountered /home/tbm/cuda/add_pybind.cu 47

我意识到这个具体示例可以使用 cupy user defined kernel 来实现,但最终目标是能够将 cupy 数组零复制传递到更大的代码库中,这将禁止重写在这个范例中。

我也找到了这个 GitHub Issue,这与我正在尝试做的相反。

修复方法是将 pyadd 的参数类型更改为 int 并将 int 转换为 float 指针,如下所示。正如评论中所指出的,这是通过引用另一个 question 来解决的。(发布时未回答)

int pyadd(int N, long px, long py)
{

  float *x = reinterpret_cast<float*> (px);
  float *y = reinterpret_cast<float*> (py);

.
.
.