JAX(XLA)与 Numba(LLVM)减少

JAX(XLA) vs Numba(LLVM) Reduction

是否有可能 CPU 使用 JAX 仅在计算时间方面与 Numba 相媲美?

编译器直接来自conda:

$ conda install -c conda-forge numba jax

这是一个一维 NumPy 数组示例

import numpy as np
import numba as nb
import jax as jx

@nb.njit
def reduce_1d_njit_serial(x):
    s = 0
    for xi in x:
        s += xi
    return s

@jx.jit
def reduce_1d_jax_serial(x):
    s = 0
    for xi in x:
        s += xi
    return s

N = 2**10
a = np.random.randn(N)

在以下

上使用timeit
  1. np.add.reduce(a) 给出 1.99 µs ...
  2. reduce_1d_njit_serial(a) 给出 1.43 µs ...
  3. reduce_1d_jax_serial(a).item() 给出 23.5 µs ...

请注意 jx.numpy.sum(a) 并使用 jx.lax.fori_loop 给出了可比较的(稍微慢一些)比较。次 reduce_1d_jax_serial.

似乎有更好的方法来减少 XLA。

编辑:编译时间未包括在内,因为打印语句继续检查结果。

在使用 JAX 执行这些类型的微基准测试时,您必须小心以确保您正在衡量您认为正在衡量的内容。 JAX Benchmarking FAQ中有一些提示。实施其中一些最佳实践后,我发现以下内容可作为您的基准:

import jax.numpy as jnp

# Native jit-compiled XLA sum
jit_sum = jx.jit(jnp.sum)

# Avoid including device transfer cost in the benchmarks
a_jax = jnp.array(a)

# Prevent measuring compilation time
_ = reduce_1d_njit_serial(a)
_ = reduce_1d_jax_serial(a_jax)
_ = jit_sum(a_jax)

%timeit np.add.reduce(a)
# 100000 loops, best of 5: 2.33 µs per loop

%timeit reduce_1d_njit_serial(a)
# 1000000 loops, best of 5: 1.43 µs per loop

%timeit reduce_1d_jax_serial(a_jax).block_until_ready()
# 100000 loops, best of 5: 6.24 µs per loop

%timeit jit_sum(a_jax).block_until_ready()
# 100000 loops, best of 5: 4.37 µs per loop

您会发现,对于这些微基准测试,JAX 比 numpy 和 numba 都慢几毫秒。那么这是否意味着 JAX 很慢?是与否;您会在 JAX FAQ: is JAX faster than numpy? 中找到该问题的更完整答案。简短的总结是这个计算非常小,差异主要是 Python 调度时间而不是在数组上操作所花费的时间。 JAX 项目没有投入太多精力来优化 Python 微基准测试的调度:这在实践中并不是那么重要,因为在 JAX 中每个程序都会产生一次成本,而不是在 numpy 中每个操作产生一次成本。