加速 JAX 中的嵌套 for 循环

Accelerating nested for-loops in JAX

我想使用 JAX 的 jit 方法加速下面示例中的嵌套 for 循环。 但是,编译时间很长,编译后的运行时间比不使用jit.

的版本更慢

我使用 jit 正确吗?我应该在这里使用 JAX 中的其他功能吗?

import time
import jax.numpy as jnp
from jax import jit
from jax import random

key = random.PRNGKey(seed=0)

width = 32
height = 64

w = random.normal(key=key, shape=(height, width))

def forward():
    a = jnp.zeros(shape=(height, width + 1))

    for i in range(height):
        a = a.at[i, 0].add(1.0)

    for j in range(width):
        for i in range(1, height-1):
            z = a[i-1, j] * w[i-1, j] \
                + a[i, j] * w[i, j] \
                + a[i+1, j] * w[i+1, j]
            a = a.at[i, j+1].set(z)

t0 = time.time()
forward()
print(time.time()-t0)

feedforward_jit = jit(forward)

t0 = time.time()
feedforward_jit()
print(time.time()-t0)

对您的问题的简短回答是:要优化循环,您应该尽一切可能从程序中删除循环。

JAX(如 NumPy)是一种建立在数组操作基础上的语言,任何时候你求助于循环数组的维度,JAX(如 NumPy)都会比你希望的慢。在 JIT 编译期间尤其如此:JAX 会在将操作发送到 XLA 之前展平循环,而 XLA 编译时间的比例大致与发送给它的操作数的平方成正比,因此嵌套循环是快速创建 非常 编译很慢。

那么如何避免这些循环呢?首先,让我们重新定义您的函数,使其接受输入和 returns 输出(考虑到 JAX 的死代码消除和异步调度,我认为您的初始基准测试并没有告诉您您认为它们是什么;请参阅 Benchmarking JAX code 一些提示):

def forward(w):
  height, width = w.shape
  a = jnp.zeros(shape=(height, width + 1))

  for i in range(height):
    a = a.at[i, 0].add(1.0)

  for j in range(width):
    for i in range(1, height-1):
      z = (a[i-1, j] * w[i-1, j]
           + a[i, j] * w[i, j]
           + a[i+1, j] * w[i+1, j])
      a = a.at[i, j+1].set(z)
  return a

第一个循环是可以用单行向量化更新代替的情况:a = a.at[:, 0].set(1)。查看下一个块的内部循环,代码似乎沿着每一列进行了卷积。让我们使用 jnp.convolve 来更有效地做到这一点。使用这两个优化结果:

def forward2(w):
  height, width = w.shape
  a = jnp.zeros((height, width + 1)).at[:, 0].set(1)
  kernel = jnp.ones(3)
  for j in range(width):
    conv = jnp.convolve(a[:, j] * w[:, j], kernel, mode='valid')
    a = a.at[1:-1, j + 1].set(conv)
  return a

接下来让我们看看循环宽度。这里比较棘手,因为每次迭代都取决于上一次的结果。一种表达方式是 lax.scan, which is one of JAX's built-in control flow operators。你可以这样做:

def forward3(w):
  def body(carry, w):
    conv = jnp.convolve(carry * w, kernel, mode='valid')
    out = jnp.zeros_like(w).at[1:-1].set(conv)
    return out, out
  init = jnp.ones(w.shape[0])
  kernel = jnp.ones(3)
  return jnp.vstack([
      init, lax.scan(body, jnp.ones(w.shape[0]), w.T)[1]]).T

我们可以快速确认这三种方法给出相同的输出:

width = 32
height = 64
w = random.normal(key=key, shape=(height, width))

result1 = forward(w)
result2 = forward2(w)
result3 = forward3(w)

assert jnp.allclose(result1, result2)
assert jnp.allclose(result2, result3)

使用 IPython 的 %time 魔法,我们可以大致了解每种方法的计算时间,这里是 CPU 后端(注意使用 block_until_ready() 来说明 JAX 的 Asynchronous dispatch):

%time forward(w).block_until_ready()
# CPU times: user 23 s, sys: 248 ms, total: 23.3 s
# Wall time: 22.9 s

%time forward2(w).block_until_ready()
# CPU times: user 117 ms, sys: 866 µs, total: 118 ms
# Wall time: 118 ms

%time forward3(w).block_until_ready()
# CPU times: user 93.2 ms, sys: 2.96 ms, total: 96.1 ms
# Wall time: 94 ms

您可以在 https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow 阅读更多关于 JAX 和控制流的信息。