Jax (google) 中是否有类似的 CUDA threadId?

Is there a CUDA threadId alike in Jax (google)?

我正在尝试了解 jax.vmap/pmap(jax:https://jax.readthedocs.io/)的行为。 CUDA有threadId让你知道是哪个线程在执行代码,jax中有没有类似的概念? (jax.process_id 不是)

不,在 JAX 中没有与 CUDA threadid 的真正模拟。有关 GPU 线程分配的详细信息由 XLA 编译器在较低级别处理,我不知道有什么直接的 API 可以将此信息返回到 JAX 的 Python 运行时。

JAX 确实提供更高级别的设备分配处理的一种情况是使用 pmap;在这种情况下,如果您需要依赖于执行映射代码的设备的逻辑,您可以显式地将一组设备 ID 传递给 pmapped 函数。例如,我 运行 在 8 设备系统上执行以下操作:

import jax
import jax.numpy as jnp

num_devices = jax.device_count()

def f(device, data):
  return data + device

device_index = jnp.arange(num_devices)
data = jnp.zeros((num_devices, 10))

jax.pmap(f)(device_index, data)

# ShardedDeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
#                     [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#                     [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
#                     [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
#                     [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
#                     [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
#                     [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
#                     [7., 7., 7., 7., 7., 7., 7., 7., 7., 7.]], dtype=float32)