在 Metal 中同步网格中的所有线程

Synchronizing all threads in a grid in Metal

我正在尝试在 Metal 中为 n 大小的向量编写范数或平方长度函数。为此,我计划让每个线程对每个元素进行平方,然后选择一个线程对所有元素求和。

这是我当前的内核:

#include <metal_stdlib>
#include <metal_compute>
using namespace metal;

kernel void length_squared(const device float *x [[ buffer(0) ]],
                           device float *s [[ buffer(1) ]],
                           device float *out [[ buffer(2) ]],
                           uint gid [[ thread_position_in_grid ]],
                           uint numElements [[ threads_per_grid ]])
{
    s[gid] = x[gid];// * x[gid];
    simdgroup_barrier(mem_flags::mem_none);
    if(gid == 0){
        for(uint i = 0; i < numElements; i++){
            *out += s[i];
        }
    }
}

不幸的是,此代码无法编译,"Use of Undeclared Identifier simdgroup_barrier"。该方法记录在 Metal Shading Language Specification 中。

有人遇到过这种情况吗?或者知道如何同步网格中的所有线程? threadgroup_barrier 没有为我实现完全同步。

我是否错误地处理了这个问题?同步此操作的最佳方法是什么?

SIMD 组小于线程组,因此无法进行同步。

相反,您需要使用 parallel reduction to sum up the values in parallel. Here 是我找到的一些 Metal 代码。

不过,如果您不介意一个线程完成所有求和,您可以 运行 一个单独的内核,只有一个线程来完成求和。当然,这可能会很慢。