如何在 CUDA 中使用同步线程进行扫描算法 (Hillis-Steele)

How to use syncthreads in CUDA for a scan algorithm (Hillis-Steele)

我正在尝试实施扫描算法 (Hillis-Steele),但我在理解如何在 CUDA 上正确执行时遇到了一些问题。这是使用 pyCUDA 的最小示例:

import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
from pycuda.compiler import SourceModule

#compile cuda code
mod = SourceModule('''
__global__ void scan(int * addresses){
    for(int idx=1; idx <= threadIdx.x; idx <<= 1){
        int new_value = addresses[threadIdx.x] + addresses[threadIdx.x - idx];
        __syncthreads();
        addresses[threadIdx.x] = new_value;
    }    
}
''')
func = mod.get_function("scan")

#Initialize an array with 1's
addresses_h = np.full((896,), 1, dtype='i4')
addresses_d = cuda.to_device(addresses_h)

#Launch kernel and copy back the result
threads_x = 896
func(addresses_d, block=(threads_x, 1, 1), grid=(1, 1))
addresses_h = cuda.from_device(addresses_d, addresses_h.shape, addresses_h.dtype)

# Check the result is correct
for i, n in enumerate(addresses_h):
    assert i+1 == n

我的问题是关于 __syncthreads()。如您所见,我在 for 循环中调用 __syncthreads() 并且并非每个线程都会执行相同次数的代码:

ThreadID - Times it will execute for loop
       0 :  0 times
       1 :  1 times
   2-  3 :  2 times
   4-  7 :  3 times
   8- 15 :  4 times
  16- 31 :  5 times
  32- 63 :  6 times
  64-127 :  7 times
 128-255 :  8 times
 256-511 :  9 times
 512-896 : 10 times

同一个 warp 中可以有多个线程调用 syncthreads 的次数不同。在那种情况下会发生什么?不执行相同代码的线程如何同步?

在示例代码中,我们从一个全为 1 的数组开始,在输出中得到索引+1 作为每个位置的值。它正在计算正确答案。是"by chance"还是代码正确?

如果这不是对同步线程的正确使用,我如何使用 cuda 实现这样的算法?

If this is not a proper use of syncthreads, how could I implement such algoritm using cuda?

一种典型的方法是将条件代码与 __syncthreads() 调用分开。使用条件代码确定哪些线程将参与。

这是您发布的代码的简单转换,应该会产生相同的结果,没有任何违规(即所有线程都将参与每个 __syncthreads() 操作):

mod = SourceModule('''
__global__ void scan(int * addresses){
    for(int i=1; i < blockDim.x; i <<= 1){
        int new_value;
        if (threadIdx.x >= i) new_value = addresses[threadIdx.x] + addresses[threadIdx.x - i];
        __syncthreads();
        if (threadIdx.x >= i) addresses[threadIdx.x] = new_value;
    }    
}
''')

我并不是说这是一个完整和正确的扫描,或者它是最佳的,或者任何类似的东西。我只是展示如何转换您的代码,以避免您所拥有的内容中固有的违规行为。

如果您想了解更多扫描方法,this is one source. But if you actually need a scan operation, I would suggest using thrust or cub