JAX(XLA)与 Numba(LLVM)减少
JAX(XLA) vs Numba(LLVM) Reduction
是否有可能 CPU 使用 JAX 仅在计算时间方面与 Numba 相媲美?
编译器直接来自conda
:
$ conda install -c conda-forge numba jax
这是一个一维 NumPy 数组示例
import numpy as np
import numba as nb
import jax as jx
@nb.njit
def reduce_1d_njit_serial(x):
s = 0
for xi in x:
s += xi
return s
@jx.jit
def reduce_1d_jax_serial(x):
s = 0
for xi in x:
s += xi
return s
N = 2**10
a = np.random.randn(N)
在以下
上使用timeit
np.add.reduce(a)
给出 1.99 µs ...
reduce_1d_njit_serial(a)
给出 1.43 µs ...
reduce_1d_jax_serial(a).item()
给出 23.5 µs ...
请注意 jx.numpy.sum(a)
并使用 jx.lax.fori_loop
给出了可比较的(稍微慢一些)比较。次 reduce_1d_jax_serial
.
似乎有更好的方法来减少 XLA。
编辑:编译时间未包括在内,因为打印语句继续检查结果。
在使用 JAX 执行这些类型的微基准测试时,您必须小心以确保您正在衡量您认为正在衡量的内容。 JAX Benchmarking FAQ中有一些提示。实施其中一些最佳实践后,我发现以下内容可作为您的基准:
import jax.numpy as jnp
# Native jit-compiled XLA sum
jit_sum = jx.jit(jnp.sum)
# Avoid including device transfer cost in the benchmarks
a_jax = jnp.array(a)
# Prevent measuring compilation time
_ = reduce_1d_njit_serial(a)
_ = reduce_1d_jax_serial(a_jax)
_ = jit_sum(a_jax)
%timeit np.add.reduce(a)
# 100000 loops, best of 5: 2.33 µs per loop
%timeit reduce_1d_njit_serial(a)
# 1000000 loops, best of 5: 1.43 µs per loop
%timeit reduce_1d_jax_serial(a_jax).block_until_ready()
# 100000 loops, best of 5: 6.24 µs per loop
%timeit jit_sum(a_jax).block_until_ready()
# 100000 loops, best of 5: 4.37 µs per loop
您会发现,对于这些微基准测试,JAX 比 numpy 和 numba 都慢几毫秒。那么这是否意味着 JAX 很慢?是与否;您会在 JAX FAQ: is JAX faster than numpy? 中找到该问题的更完整答案。简短的总结是这个计算非常小,差异主要是 Python 调度时间而不是在数组上操作所花费的时间。 JAX 项目没有投入太多精力来优化 Python 微基准测试的调度:这在实践中并不是那么重要,因为在 JAX 中每个程序都会产生一次成本,而不是在 numpy 中每个操作产生一次成本。
是否有可能 CPU 使用 JAX 仅在计算时间方面与 Numba 相媲美?
编译器直接来自conda
:
$ conda install -c conda-forge numba jax
这是一个一维 NumPy 数组示例
import numpy as np
import numba as nb
import jax as jx
@nb.njit
def reduce_1d_njit_serial(x):
s = 0
for xi in x:
s += xi
return s
@jx.jit
def reduce_1d_jax_serial(x):
s = 0
for xi in x:
s += xi
return s
N = 2**10
a = np.random.randn(N)
在以下
上使用timeit
np.add.reduce(a)
给出1.99 µs ...
reduce_1d_njit_serial(a)
给出1.43 µs ...
reduce_1d_jax_serial(a).item()
给出23.5 µs ...
请注意 jx.numpy.sum(a)
并使用 jx.lax.fori_loop
给出了可比较的(稍微慢一些)比较。次 reduce_1d_jax_serial
.
似乎有更好的方法来减少 XLA。
编辑:编译时间未包括在内,因为打印语句继续检查结果。
在使用 JAX 执行这些类型的微基准测试时,您必须小心以确保您正在衡量您认为正在衡量的内容。 JAX Benchmarking FAQ中有一些提示。实施其中一些最佳实践后,我发现以下内容可作为您的基准:
import jax.numpy as jnp
# Native jit-compiled XLA sum
jit_sum = jx.jit(jnp.sum)
# Avoid including device transfer cost in the benchmarks
a_jax = jnp.array(a)
# Prevent measuring compilation time
_ = reduce_1d_njit_serial(a)
_ = reduce_1d_jax_serial(a_jax)
_ = jit_sum(a_jax)
%timeit np.add.reduce(a)
# 100000 loops, best of 5: 2.33 µs per loop
%timeit reduce_1d_njit_serial(a)
# 1000000 loops, best of 5: 1.43 µs per loop
%timeit reduce_1d_jax_serial(a_jax).block_until_ready()
# 100000 loops, best of 5: 6.24 µs per loop
%timeit jit_sum(a_jax).block_until_ready()
# 100000 loops, best of 5: 4.37 µs per loop
您会发现,对于这些微基准测试,JAX 比 numpy 和 numba 都慢几毫秒。那么这是否意味着 JAX 很慢?是与否;您会在 JAX FAQ: is JAX faster than numpy? 中找到该问题的更完整答案。简短的总结是这个计算非常小,差异主要是 Python 调度时间而不是在数组上操作所花费的时间。 JAX 项目没有投入太多精力来优化 Python 微基准测试的调度:这在实践中并不是那么重要,因为在 JAX 中每个程序都会产生一次成本,而不是在 numpy 中每个操作产生一次成本。