有没有一种方法可以加快使用 JAX 对向量进行索引的速度?
Is there a way to speed up indexing a vector with JAX?
我正在索引向量并使用 JAX,但我注意到在简单索引数组时与 numpy 相比速度明显变慢。例如,考虑在 JAX numpy 和普通 numpy 中制作一个基本数组:
import jax.numpy as jnp
import numpy as onp
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)
然后简单地在两个整数之间建立索引,对于 JAX(在 GPU 上)这给出了时间:
%timeit jax_array[435:852]
1000 loops, best of 5: 1.38 ms per loop
对于 numpy,这给出了时间:
%timeit numpy_array[435:852]
1000000 loops, best of 5: 271 ns per loop
所以 numpy 比 JAX 快 5000 倍。当 JAX 在 CPU 上时,则
%timeit jax_array[435:852]
1000 loops, best of 5: 577 µs per loop
如此之快,但仍然比 numpy 慢 2000 倍。我为此使用 Google Colab 笔记本,因此 installation/CUDA.
应该没有问题
我错过了什么吗?我意识到索引对于 JAX 和 numpy 是不同的,正如 the JAX 'sharp edges' documentation 给出的那样,但是我找不到任何方法来执行赋值,例如
new_array = jax_array[435:852]
没有明显放缓。我无法避免对数组进行索引,因为这在我的程序中是必要的。
简短的回答:要加快 JAX 的速度,请使用 jit
。
长答案:
您通常应该期望在 op-by-op 模式下使用 JAX 的单个操作比 numpy 中的类似操作慢。这是因为 JAX 执行有一些固定的每个 python 函数调用开销涉及将编译下推到 XLA。
即使像索引这样看似简单的操作也是根据多个 XLA 操作实现的,其中(在 JIT 之外)每个操作都会增加自己的调用开销。您可以使用 make_jaxpr
转换来查看此序列,以检查函数是如何用原始操作表示的:
from jax import make_jaxpr
f = lambda x: x[435:852]
make_jaxpr(f)(jax_array)
# { lambda ; a.
# let b = broadcast_in_dim[ broadcast_dimensions=( )
# shape=(1,) ] 435
# c = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
# indices_are_sorted=True
# slice_sizes=(417,)
# unique_indices=True ] a b
# d = broadcast_in_dim[ broadcast_dimensions=(0,)
# shape=(417,) ] c
# in (d,) }
(有关如何阅读此内容的信息,请参阅 Understanding Jaxprs)。
JAX 优于 numpy 的地方不在于单个小操作(其中 JAX 调度开销占主导地位),而是在于通过 jit
转换编译的操作序列。因此,例如,比较索引的 JIT 编译版本和非 JIT 编译版本:
%timeit f(jax_array).block_until_ready()
# 1000 loops, best of 5: 612 µs per loop
f_jit = jit(f)
f_jit(jax_array) # trigger compilation
%timeit f_jit(jax_array).block_until_ready()
# 100000 loops, best of 5: 4.34 µs per loop
(请注意,由于 JAX 的 asynchronous dispatch,准确的微基准测试需要 block_until_ready()
)
JIT 编译此代码可提供 150 倍的加速。由于 JAX 的几毫秒分派开销,它仍然没有 numpy 快,但是对于 JIT,这种开销只会产生一次。当您将微基准测试转移到更复杂的现实世界计算序列时,那几毫秒将不再占主导地位,并且 XLA 编译器提供的优化可以使 JAX 比等效的 numpy 计算快得多。
我正在索引向量并使用 JAX,但我注意到在简单索引数组时与 numpy 相比速度明显变慢。例如,考虑在 JAX numpy 和普通 numpy 中制作一个基本数组:
import jax.numpy as jnp
import numpy as onp
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)
然后简单地在两个整数之间建立索引,对于 JAX(在 GPU 上)这给出了时间:
%timeit jax_array[435:852]
1000 loops, best of 5: 1.38 ms per loop
对于 numpy,这给出了时间:
%timeit numpy_array[435:852]
1000000 loops, best of 5: 271 ns per loop
所以 numpy 比 JAX 快 5000 倍。当 JAX 在 CPU 上时,则
%timeit jax_array[435:852]
1000 loops, best of 5: 577 µs per loop
如此之快,但仍然比 numpy 慢 2000 倍。我为此使用 Google Colab 笔记本,因此 installation/CUDA.
应该没有问题我错过了什么吗?我意识到索引对于 JAX 和 numpy 是不同的,正如 the JAX 'sharp edges' documentation 给出的那样,但是我找不到任何方法来执行赋值,例如
new_array = jax_array[435:852]
没有明显放缓。我无法避免对数组进行索引,因为这在我的程序中是必要的。
简短的回答:要加快 JAX 的速度,请使用 jit
。
长答案:
您通常应该期望在 op-by-op 模式下使用 JAX 的单个操作比 numpy 中的类似操作慢。这是因为 JAX 执行有一些固定的每个 python 函数调用开销涉及将编译下推到 XLA。
即使像索引这样看似简单的操作也是根据多个 XLA 操作实现的,其中(在 JIT 之外)每个操作都会增加自己的调用开销。您可以使用 make_jaxpr
转换来查看此序列,以检查函数是如何用原始操作表示的:
from jax import make_jaxpr
f = lambda x: x[435:852]
make_jaxpr(f)(jax_array)
# { lambda ; a.
# let b = broadcast_in_dim[ broadcast_dimensions=( )
# shape=(1,) ] 435
# c = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
# indices_are_sorted=True
# slice_sizes=(417,)
# unique_indices=True ] a b
# d = broadcast_in_dim[ broadcast_dimensions=(0,)
# shape=(417,) ] c
# in (d,) }
(有关如何阅读此内容的信息,请参阅 Understanding Jaxprs)。
JAX 优于 numpy 的地方不在于单个小操作(其中 JAX 调度开销占主导地位),而是在于通过 jit
转换编译的操作序列。因此,例如,比较索引的 JIT 编译版本和非 JIT 编译版本:
%timeit f(jax_array).block_until_ready()
# 1000 loops, best of 5: 612 µs per loop
f_jit = jit(f)
f_jit(jax_array) # trigger compilation
%timeit f_jit(jax_array).block_until_ready()
# 100000 loops, best of 5: 4.34 µs per loop
(请注意,由于 JAX 的 asynchronous dispatch,准确的微基准测试需要 block_until_ready()
)
JIT 编译此代码可提供 150 倍的加速。由于 JAX 的几毫秒分派开销,它仍然没有 numpy 快,但是对于 JIT,这种开销只会产生一次。当您将微基准测试转移到更复杂的现实世界计算序列时,那几毫秒将不再占主导地位,并且 XLA 编译器提供的优化可以使 JAX 比等效的 numpy 计算快得多。