为什么这个函数在 JAX 和 numpy 中比较慢?
Why is this function slower in JAX vs numpy?
我有如下所示的 numpy 函数,我正在尝试使用 JAX 对其进行优化,但无论出于何种原因,它的速度都较慢。
有人可以指出我可以做些什么来提高这里的性能吗?我怀疑它与 Cg_new 的列表理解有关,但将其分开不会在 JAX 中产生任何进一步的性能提升。
import numpy as np
def testFunction_numpy(C, Mi, C_new, Mi_new):
Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
Cg_new = np.zeros((1, len(Mi[0])))
invertCsensor_new = np.linalg.inv(C_new)
Wg_new = np.dot(invertCsensor_new, Mi_new)
Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop
这是 JAX 等效项:
import jax.numpy as jnp
import numpy as np
import jax
def testFunction_JAX(C, Mi, C_new, Mi_new):
Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
Cg_new = jnp.zeros((1, len(Mi[0])))
invertCsensor_new = jnp.linalg.inv(C_new)
Wg_new = jnp.dot(invertCsensor_new, Mi_new)
Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)
jitter = jax.jit(testFunction_JAX)
%timeit jitter(C, Mi, C_new, Mi_new)
#1 loop, best of 3: 4.96 ms per loop
有关 JAX 和 NumPy 之间基准比较的一般注意事项,请参阅 https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
至于您的特定代码:当 JAX jit 编译遇到 Python 控制流(包括列表理解)时,它会有效地展平循环并分阶段执行完整的操作序列。这会导致 jit 编译时间变慢和代码不理想。幸运的是,您的函数中的列表理解很容易用原生 numpy 广播来表达。此外,您还可以进行另外两项改进:
- 在计算之前不需要转发声明
Wg_new
和 Cg_new
- 在计算
dot(inv(A), B)
时,使用 np.linalg.solve
比显式计算逆运算更高效和精确。
对 numpy 和 JAX 版本进行这三项改进会产生以下结果:
def testFunction_numpy_v2(C, Mi, C_new, Mi_new):
Wg_new = np.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return C_new, Mi_new, Wg_new, Cg_new
@jax.jit
def testFunction_JAX_v2(C, Mi, C_new, Mi_new):
Wg_new = jnp.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return C_new, Mi_new, Wg_new, Cg_new
%timeit testFunction_numpy_v2(C, Mi, C_new, Mi_new)
# 1000 loops, best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
# 1000 loops, best of 3: 1.35 ms per loop
由于改进了实施,这两个功能都比以前快了一点。但是,您会注意到 JAX 在这里仍然比 numpy 慢;这在某种程度上是意料之中的,因为对于这种简单程度的功能,JAX 和 numpy 都有效地生成了在 CPU 架构上执行的相同的短系列 BLAS 和 LAPACK 调用。与 numpy 的参考实现相比,没有太大的改进空间,而且对于如此小的数组,JAX 的开销是显而易见的。
我用 perfplot 在一系列问题规模上测试了这个问题。结果:jax 稍微快了一点。 jax 在这里没有优于 numpy 的原因是它在 CPU 上是 运行(就像 NumPy),而在这里,NumPy 已经非常优化。 (它在后台使用 BLAS/LAPACK。)
重现情节的代码:
import jax.numpy as jnp
import jax
import numpy as np
import perfplot
def setup(n):
C_new = np.random.rand(n, n)
Mi_new = np.random.rand(n, 8)
return C_new, Mi_new
def testFunction_numpy_v2(C_new, Mi_new):
Wg_new = np.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return Wg_new, Cg_new
@jax.jit
def testFunction_JAX_v2(C_new, Mi_new):
Wg_new = jnp.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return Wg_new, Cg_new
b = perfplot.bench(
setup=setup,
kernels=[testFunction_numpy_v2, testFunction_JAX_v2],
n_range=[2 ** k for k in range(14)],
equality_check=None
)
b.save("out.png")
b.show()
我有如下所示的 numpy 函数,我正在尝试使用 JAX 对其进行优化,但无论出于何种原因,它的速度都较慢。
有人可以指出我可以做些什么来提高这里的性能吗?我怀疑它与 Cg_new 的列表理解有关,但将其分开不会在 JAX 中产生任何进一步的性能提升。
import numpy as np
def testFunction_numpy(C, Mi, C_new, Mi_new):
Wg_new = np.zeros((len(Mi_new[:,0]), len(Mi[0])))
Cg_new = np.zeros((1, len(Mi[0])))
invertCsensor_new = np.linalg.inv(C_new)
Wg_new = np.dot(invertCsensor_new, Mi_new)
Cg_new = [np.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
%timeit testFunction_numpy(C, Mi, C_new, Mi_new)
#1000 loops, best of 3: 1.73 ms per loop
这是 JAX 等效项:
import jax.numpy as jnp
import numpy as np
import jax
def testFunction_JAX(C, Mi, C_new, Mi_new):
Wg_new = jnp.zeros((len(Mi_new[:,0]), len(Mi[0])))
Cg_new = jnp.zeros((1, len(Mi[0])))
invertCsensor_new = jnp.linalg.inv(C_new)
Wg_new = jnp.dot(invertCsensor_new, Mi_new)
Cg_new = [jnp.dot(((-0.5*(Mi_new[:,m].conj().T))), (Wg_new[:,m])) for m in range(0, len(Mi[0]))]
return C_new, Mi_new, Wg_new, Cg_new
C = np.random.rand(483,483)
Mi = np.random.rand(483,8)
C_new = np.random.rand(198,198)
Mi_new = np.random.rand(198,8)
C = jnp.asarray(C)
Mi = jnp.asarray(Mi)
C_new = jnp.asarray(C_new)
Mi_new = jnp.asarray(Mi_new)
jitter = jax.jit(testFunction_JAX)
%timeit jitter(C, Mi, C_new, Mi_new)
#1 loop, best of 3: 4.96 ms per loop
有关 JAX 和 NumPy 之间基准比较的一般注意事项,请参阅 https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
至于您的特定代码:当 JAX jit 编译遇到 Python 控制流(包括列表理解)时,它会有效地展平循环并分阶段执行完整的操作序列。这会导致 jit 编译时间变慢和代码不理想。幸运的是,您的函数中的列表理解很容易用原生 numpy 广播来表达。此外,您还可以进行另外两项改进:
- 在计算之前不需要转发声明
Wg_new
和Cg_new
- 在计算
dot(inv(A), B)
时,使用np.linalg.solve
比显式计算逆运算更高效和精确。
对 numpy 和 JAX 版本进行这三项改进会产生以下结果:
def testFunction_numpy_v2(C, Mi, C_new, Mi_new):
Wg_new = np.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return C_new, Mi_new, Wg_new, Cg_new
@jax.jit
def testFunction_JAX_v2(C, Mi, C_new, Mi_new):
Wg_new = jnp.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return C_new, Mi_new, Wg_new, Cg_new
%timeit testFunction_numpy_v2(C, Mi, C_new, Mi_new)
# 1000 loops, best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
# 1000 loops, best of 3: 1.35 ms per loop
由于改进了实施,这两个功能都比以前快了一点。但是,您会注意到 JAX 在这里仍然比 numpy 慢;这在某种程度上是意料之中的,因为对于这种简单程度的功能,JAX 和 numpy 都有效地生成了在 CPU 架构上执行的相同的短系列 BLAS 和 LAPACK 调用。与 numpy 的参考实现相比,没有太大的改进空间,而且对于如此小的数组,JAX 的开销是显而易见的。
我用 perfplot 在一系列问题规模上测试了这个问题。结果:jax 稍微快了一点。 jax 在这里没有优于 numpy 的原因是它在 CPU 上是 运行(就像 NumPy),而在这里,NumPy 已经非常优化。 (它在后台使用 BLAS/LAPACK。)
重现情节的代码:
import jax.numpy as jnp
import jax
import numpy as np
import perfplot
def setup(n):
C_new = np.random.rand(n, n)
Mi_new = np.random.rand(n, 8)
return C_new, Mi_new
def testFunction_numpy_v2(C_new, Mi_new):
Wg_new = np.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return Wg_new, Cg_new
@jax.jit
def testFunction_JAX_v2(C_new, Mi_new):
Wg_new = jnp.linalg.solve(C_new, Mi_new)
Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
return Wg_new, Cg_new
b = perfplot.bench(
setup=setup,
kernels=[testFunction_numpy_v2, testFunction_JAX_v2],
n_range=[2 ** k for k in range(14)],
equality_check=None
)
b.save("out.png")
b.show()