Jax (google) 中是否有类似的 CUDA threadId?
Is there a CUDA threadId alike in Jax (google)?
我正在尝试了解 jax.vmap/pmap
(jax:https://jax.readthedocs.io/)的行为。 CUDA有threadId让你知道是哪个线程在执行代码,jax中有没有类似的概念? (jax.process_id
不是)
不,在 JAX 中没有与 CUDA threadid 的真正模拟。有关 GPU 线程分配的详细信息由 XLA 编译器在较低级别处理,我不知道有什么直接的 API 可以将此信息返回到 JAX 的 Python 运行时。
JAX 确实提供更高级别的设备分配处理的一种情况是使用 pmap
;在这种情况下,如果您需要依赖于执行映射代码的设备的逻辑,您可以显式地将一组设备 ID 传递给 pmapped 函数。例如,我 运行 在 8 设备系统上执行以下操作:
import jax
import jax.numpy as jnp
num_devices = jax.device_count()
def f(device, data):
return data + device
device_index = jnp.arange(num_devices)
data = jnp.zeros((num_devices, 10))
jax.pmap(f)(device_index, data)
# ShardedDeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
# [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
# [7., 7., 7., 7., 7., 7., 7., 7., 7., 7.]], dtype=float32)
我正在尝试了解 jax.vmap/pmap
(jax:https://jax.readthedocs.io/)的行为。 CUDA有threadId让你知道是哪个线程在执行代码,jax中有没有类似的概念? (jax.process_id
不是)
不,在 JAX 中没有与 CUDA threadid 的真正模拟。有关 GPU 线程分配的详细信息由 XLA 编译器在较低级别处理,我不知道有什么直接的 API 可以将此信息返回到 JAX 的 Python 运行时。
JAX 确实提供更高级别的设备分配处理的一种情况是使用 pmap
;在这种情况下,如果您需要依赖于执行映射代码的设备的逻辑,您可以显式地将一组设备 ID 传递给 pmapped 函数。例如,我 运行 在 8 设备系统上执行以下操作:
import jax
import jax.numpy as jnp
num_devices = jax.device_count()
def f(device, data):
return data + device
device_index = jnp.arange(num_devices)
data = jnp.zeros((num_devices, 10))
jax.pmap(f)(device_index, data)
# ShardedDeviceArray([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],
# [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],
# [7., 7., 7., 7., 7., 7., 7., 7., 7., 7.]], dtype=float32)