有没有一种方法可以加快使用 JAX 对向量进行索引的速度?

Is there a way to speed up indexing a vector with JAX?

我正在索引向量并使用 JAX,但我注意到在简单索引数组时与 numpy 相比速度明显变慢。例如,考虑在 JAX numpy 和普通 numpy 中制作一个基本数组:

import jax.numpy as jnp
import numpy as onp 
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)

然后简单地在两个整数之间建立索引,对于 JAX(在 GPU 上)这给出了时间:

%timeit jax_array[435:852]

1000 loops, best of 5: 1.38 ms per loop

对于 numpy,这给出了时间:

%timeit numpy_array[435:852]

1000000 loops, best of 5: 271 ns per loop

所以 numpy 比 JAX 快 5000 倍。当 JAX 在 CPU 上时,则

%timeit jax_array[435:852]

1000 loops, best of 5: 577 µs per loop

如此之快,但仍然比 numpy 慢 2000 倍。我为此使用 Google Colab 笔记本,因此 installation/CUDA.

应该没有问题

我错过了什么吗?我意识到索引对于 JAX 和 numpy 是不同的,正如 the JAX 'sharp edges' documentation 给出的那样,但是我找不到任何方法来执行赋值,例如

new_array = jax_array[435:852]

没有明显放缓。我无法避免对数组进行索引,因为这在我的程序中是必要的。

简短的回答:要加快 JAX 的速度,请使用 jit

长答案:

您通常应该期望在 op-by-op 模式下使用 JAX 的单个操作比 numpy 中的类似操作慢。这是因为 JAX 执行有一些固定的每个 python 函数调用开销涉及将编译下推到 XLA。

即使像索引这样看似简单的操作也是根据多个 XLA 操作实现的,其中(在 JIT 之外)每个操作都会增加自己的调用开销。您可以使用 make_jaxpr 转换来查看此序列,以检查函数是如何用原始操作表示的:

from jax import make_jaxpr
f = lambda x: x[435:852]
make_jaxpr(f)(jax_array)
# { lambda  ; a.
#   let b = broadcast_in_dim[ broadcast_dimensions=(  )
#                             shape=(1,) ] 435
#       c = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
#                   indices_are_sorted=True
#                   slice_sizes=(417,)
#                   unique_indices=True ] a b
#       d = broadcast_in_dim[ broadcast_dimensions=(0,)
#                             shape=(417,) ] c
#   in (d,) }

(有关如何阅读此内容的信息,请参阅 Understanding Jaxprs)。

JAX 优于 numpy 的地方不在于单个小操作(其中 JAX 调度开销占主导地位),而是在于通过 jit 转换编译的操作序列。因此,例如,比较索引的 JIT 编译版本和非 JIT 编译版本:

%timeit f(jax_array).block_until_ready()
# 1000 loops, best of 5: 612 µs per loop

f_jit = jit(f)
f_jit(jax_array)  # trigger compilation
%timeit f_jit(jax_array).block_until_ready()
# 100000 loops, best of 5: 4.34 µs per loop

(请注意,由于 JAX 的 asynchronous dispatch,准确的微基准测试需要 block_until_ready()

JIT 编译此代码可提供 150 倍的加速。由于 JAX 的几毫秒分派开销,它仍然没有 numpy 快,但是对于 JIT,这种开销只会产生一次。当您将微基准测试转移到更复杂的现实世界计算序列时,那几毫秒将不再占主导地位,并且 XLA 编译器提供的优化可以使 JAX 比等效的 numpy 计算快得多。