两种对矩阵元素求幂的方法的比较
Compasion of two approaches of exponentiating elements of a matrix
我有两种对 jnp = jax.numpy
中的矩阵求幂的方法。一种
直截了当的一个:
jnp.exp(-X/reg)
还有一些额外的操作:
def exp_reg(X, reg):
K = jnp.empty_like(X)
K = jnp.divide(X, -reg)
return jnp.exp(K)
然而,当我测试它们时:
%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()
尽管表面上有一些额外的开销,但第二种方法的表现却优于其他方法。我有 运行 一个 %timeit
矩阵大小为 2000 x 2000:
7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
为什么会这样?
这里的区别在于操作顺序。
在 jnp.exp(-X/reg)
中,您对 X
的每个条目取反,然后将结果的每个条目除以 reg
。这是数组 X
.
的两次遍历
在 exp_reg
中你正在否定 reg
(这可能是一个标量值?),然后将 X
除以结果。这是 X
.
的一次传球
如果 X
很大,我希望第一种方法比第二种方法稍慢,因为 X
.
的多次传递
幸运的是,由于您使用的是 JAX,因此您可以 jit
编译您的代码,在这种情况下,XLA 通常可以优化这些等效的操作顺序。事实上,对于你的两个函数,编译消除了差异:
from jax import jit
import jax.numpy as jnp
import numpy as np
def exp_reg1(X, reg):
return jnp.exp(-X/reg)
def exp_reg2(X, reg):
K = jnp.divide(X, -reg)
return jnp.exp(K)
X = jnp.array(np.random.rand(1000, 1000))
reg = 2.0
%timeit exp_reg1(X, reg)
# 100 loops, best of 3: 3.17 ms per loop
%timeit exp_reg2(X, reg)
# 100 loops, best of 3: 2.2 ms per loop
# Trigger compilation
jit(exp_reg1)(X, reg)
jit(exp_reg2)(X, reg)
%timeit jit(exp_reg1)(X, reg)
# 1000 loops, best of 3: 1.92 ms per loop
%timeit jit(exp_reg2)(X, reg)
# 100 loops, best of 3: 1.84 ms per loop
(旁注:在将操作结果分配给同名变量之前,没有理由 pre-allocate 空数组 K
)。
我有两种对 jnp = jax.numpy
中的矩阵求幂的方法。一种
直截了当的一个:
jnp.exp(-X/reg)
还有一些额外的操作:
def exp_reg(X, reg):
K = jnp.empty_like(X)
K = jnp.divide(X, -reg)
return jnp.exp(K)
然而,当我测试它们时:
%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()
尽管表面上有一些额外的开销,但第二种方法的表现却优于其他方法。我有 运行 一个 %timeit
矩阵大小为 2000 x 2000:
7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
为什么会这样?
这里的区别在于操作顺序。
在 jnp.exp(-X/reg)
中,您对 X
的每个条目取反,然后将结果的每个条目除以 reg
。这是数组 X
.
在 exp_reg
中你正在否定 reg
(这可能是一个标量值?),然后将 X
除以结果。这是 X
.
如果 X
很大,我希望第一种方法比第二种方法稍慢,因为 X
.
幸运的是,由于您使用的是 JAX,因此您可以 jit
编译您的代码,在这种情况下,XLA 通常可以优化这些等效的操作顺序。事实上,对于你的两个函数,编译消除了差异:
from jax import jit
import jax.numpy as jnp
import numpy as np
def exp_reg1(X, reg):
return jnp.exp(-X/reg)
def exp_reg2(X, reg):
K = jnp.divide(X, -reg)
return jnp.exp(K)
X = jnp.array(np.random.rand(1000, 1000))
reg = 2.0
%timeit exp_reg1(X, reg)
# 100 loops, best of 3: 3.17 ms per loop
%timeit exp_reg2(X, reg)
# 100 loops, best of 3: 2.2 ms per loop
# Trigger compilation
jit(exp_reg1)(X, reg)
jit(exp_reg2)(X, reg)
%timeit jit(exp_reg1)(X, reg)
# 1000 loops, best of 3: 1.92 ms per loop
%timeit jit(exp_reg2)(X, reg)
# 100 loops, best of 3: 1.84 ms per loop
(旁注:在将操作结果分配给同名变量之前,没有理由 pre-allocate 空数组 K
)。