加速 JAX 中的嵌套 for 循环
Accelerating nested for-loops in JAX
我想使用 JAX 的 jit
方法加速下面示例中的嵌套 for 循环。
但是,编译时间很长,编译后的运行时间比不使用jit
.
的版本更慢
我使用 jit
正确吗?我应该在这里使用 JAX 中的其他功能吗?
import time
import jax.numpy as jnp
from jax import jit
from jax import random
key = random.PRNGKey(seed=0)
width = 32
height = 64
w = random.normal(key=key, shape=(height, width))
def forward():
a = jnp.zeros(shape=(height, width + 1))
for i in range(height):
a = a.at[i, 0].add(1.0)
for j in range(width):
for i in range(1, height-1):
z = a[i-1, j] * w[i-1, j] \
+ a[i, j] * w[i, j] \
+ a[i+1, j] * w[i+1, j]
a = a.at[i, j+1].set(z)
t0 = time.time()
forward()
print(time.time()-t0)
feedforward_jit = jit(forward)
t0 = time.time()
feedforward_jit()
print(time.time()-t0)
对您的问题的简短回答是:要优化循环,您应该尽一切可能从程序中删除循环。
JAX(如 NumPy)是一种建立在数组操作基础上的语言,任何时候你求助于循环数组的维度,JAX(如 NumPy)都会比你希望的慢。在 JIT 编译期间尤其如此:JAX 会在将操作发送到 XLA 之前展平循环,而 XLA 编译时间的比例大致与发送给它的操作数的平方成正比,因此嵌套循环是快速创建 非常 编译很慢。
那么如何避免这些循环呢?首先,让我们重新定义您的函数,使其接受输入和 returns 输出(考虑到 JAX 的死代码消除和异步调度,我认为您的初始基准测试并没有告诉您您认为它们是什么;请参阅 Benchmarking JAX code 一些提示):
def forward(w):
height, width = w.shape
a = jnp.zeros(shape=(height, width + 1))
for i in range(height):
a = a.at[i, 0].add(1.0)
for j in range(width):
for i in range(1, height-1):
z = (a[i-1, j] * w[i-1, j]
+ a[i, j] * w[i, j]
+ a[i+1, j] * w[i+1, j])
a = a.at[i, j+1].set(z)
return a
第一个循环是可以用单行向量化更新代替的情况:a = a.at[:, 0].set(1)
。查看下一个块的内部循环,代码似乎沿着每一列进行了卷积。让我们使用 jnp.convolve
来更有效地做到这一点。使用这两个优化结果:
def forward2(w):
height, width = w.shape
a = jnp.zeros((height, width + 1)).at[:, 0].set(1)
kernel = jnp.ones(3)
for j in range(width):
conv = jnp.convolve(a[:, j] * w[:, j], kernel, mode='valid')
a = a.at[1:-1, j + 1].set(conv)
return a
接下来让我们看看循环宽度。这里比较棘手,因为每次迭代都取决于上一次的结果。一种表达方式是 lax.scan
, which is one of JAX's built-in control flow operators。你可以这样做:
def forward3(w):
def body(carry, w):
conv = jnp.convolve(carry * w, kernel, mode='valid')
out = jnp.zeros_like(w).at[1:-1].set(conv)
return out, out
init = jnp.ones(w.shape[0])
kernel = jnp.ones(3)
return jnp.vstack([
init, lax.scan(body, jnp.ones(w.shape[0]), w.T)[1]]).T
我们可以快速确认这三种方法给出相同的输出:
width = 32
height = 64
w = random.normal(key=key, shape=(height, width))
result1 = forward(w)
result2 = forward2(w)
result3 = forward3(w)
assert jnp.allclose(result1, result2)
assert jnp.allclose(result2, result3)
使用 IPython 的 %time
魔法,我们可以大致了解每种方法的计算时间,这里是 CPU 后端(注意使用 block_until_ready()
来说明 JAX 的 Asynchronous dispatch):
%time forward(w).block_until_ready()
# CPU times: user 23 s, sys: 248 ms, total: 23.3 s
# Wall time: 22.9 s
%time forward2(w).block_until_ready()
# CPU times: user 117 ms, sys: 866 µs, total: 118 ms
# Wall time: 118 ms
%time forward3(w).block_until_ready()
# CPU times: user 93.2 ms, sys: 2.96 ms, total: 96.1 ms
# Wall time: 94 ms
您可以在 https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow 阅读更多关于 JAX 和控制流的信息。
我想使用 JAX 的 jit
方法加速下面示例中的嵌套 for 循环。
但是,编译时间很长,编译后的运行时间比不使用jit
.
我使用 jit
正确吗?我应该在这里使用 JAX 中的其他功能吗?
import time
import jax.numpy as jnp
from jax import jit
from jax import random
key = random.PRNGKey(seed=0)
width = 32
height = 64
w = random.normal(key=key, shape=(height, width))
def forward():
a = jnp.zeros(shape=(height, width + 1))
for i in range(height):
a = a.at[i, 0].add(1.0)
for j in range(width):
for i in range(1, height-1):
z = a[i-1, j] * w[i-1, j] \
+ a[i, j] * w[i, j] \
+ a[i+1, j] * w[i+1, j]
a = a.at[i, j+1].set(z)
t0 = time.time()
forward()
print(time.time()-t0)
feedforward_jit = jit(forward)
t0 = time.time()
feedforward_jit()
print(time.time()-t0)
对您的问题的简短回答是:要优化循环,您应该尽一切可能从程序中删除循环。
JAX(如 NumPy)是一种建立在数组操作基础上的语言,任何时候你求助于循环数组的维度,JAX(如 NumPy)都会比你希望的慢。在 JIT 编译期间尤其如此:JAX 会在将操作发送到 XLA 之前展平循环,而 XLA 编译时间的比例大致与发送给它的操作数的平方成正比,因此嵌套循环是快速创建 非常 编译很慢。
那么如何避免这些循环呢?首先,让我们重新定义您的函数,使其接受输入和 returns 输出(考虑到 JAX 的死代码消除和异步调度,我认为您的初始基准测试并没有告诉您您认为它们是什么;请参阅 Benchmarking JAX code 一些提示):
def forward(w):
height, width = w.shape
a = jnp.zeros(shape=(height, width + 1))
for i in range(height):
a = a.at[i, 0].add(1.0)
for j in range(width):
for i in range(1, height-1):
z = (a[i-1, j] * w[i-1, j]
+ a[i, j] * w[i, j]
+ a[i+1, j] * w[i+1, j])
a = a.at[i, j+1].set(z)
return a
第一个循环是可以用单行向量化更新代替的情况:a = a.at[:, 0].set(1)
。查看下一个块的内部循环,代码似乎沿着每一列进行了卷积。让我们使用 jnp.convolve
来更有效地做到这一点。使用这两个优化结果:
def forward2(w):
height, width = w.shape
a = jnp.zeros((height, width + 1)).at[:, 0].set(1)
kernel = jnp.ones(3)
for j in range(width):
conv = jnp.convolve(a[:, j] * w[:, j], kernel, mode='valid')
a = a.at[1:-1, j + 1].set(conv)
return a
接下来让我们看看循环宽度。这里比较棘手,因为每次迭代都取决于上一次的结果。一种表达方式是 lax.scan
, which is one of JAX's built-in control flow operators。你可以这样做:
def forward3(w):
def body(carry, w):
conv = jnp.convolve(carry * w, kernel, mode='valid')
out = jnp.zeros_like(w).at[1:-1].set(conv)
return out, out
init = jnp.ones(w.shape[0])
kernel = jnp.ones(3)
return jnp.vstack([
init, lax.scan(body, jnp.ones(w.shape[0]), w.T)[1]]).T
我们可以快速确认这三种方法给出相同的输出:
width = 32
height = 64
w = random.normal(key=key, shape=(height, width))
result1 = forward(w)
result2 = forward2(w)
result3 = forward3(w)
assert jnp.allclose(result1, result2)
assert jnp.allclose(result2, result3)
使用 IPython 的 %time
魔法,我们可以大致了解每种方法的计算时间,这里是 CPU 后端(注意使用 block_until_ready()
来说明 JAX 的 Asynchronous dispatch):
%time forward(w).block_until_ready()
# CPU times: user 23 s, sys: 248 ms, total: 23.3 s
# Wall time: 22.9 s
%time forward2(w).block_until_ready()
# CPU times: user 117 ms, sys: 866 µs, total: 118 ms
# Wall time: 118 ms
%time forward3(w).block_until_ready()
# CPU times: user 93.2 ms, sys: 2.96 ms, total: 96.1 ms
# Wall time: 94 ms
您可以在 https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow 阅读更多关于 JAX 和控制流的信息。