JAX 运行比 NumPy 慢吗?
Does JAX run slower than NumPy?
我最近开始学习 JAX。我用 NumPy 编写了一个简短的代码片段,并用 JAX 编写了等价的版本。我原以为 JAX 会更快,但当我对代码进行性能测试时,NumPy 代码比 JAX 代码快得多。我想知道这是普遍现象,还是我的实现存在问题。
NumPy代码:
import numpy as np
from time import time as tm
def gp(x):
    """Rectify ``x`` element-wise (ReLU): negative entries become zero.

    Args:
        x: Array exposing ``.shape`` (e.g. a NumPy array).

    Returns:
        The element-wise maximum of ``x`` and a zero array of the same shape.
    """
    # np.zeros builds a float64 array shaped like x, so np.maximum compares
    # each entry of x against 0.
    return np.maximum(np.zeros(x.shape), x)
# -- inputs
n_q1 = 25
n_q2 = 5
x = np.random.rand(1, n_q1)  # todo: changes w/ time
y = np.random.rand(1, n_q2)  # todo: changes w/ time
# -- activations
n_p1 = 3
n_p2 = 2
v_p1 = np.random.rand(1, n_p1)
v_p2 = np.random.rand(1, n_p2)
a_p1 = 0.5
a_p2 = 0.5
# -- weights
W_q1p1 = np.random.rand(n_q1, n_p1)
W_p2q2 = np.random.rand(n_p2, n_q2)
W_p1p1 = np.random.rand(n_p1, n_p1)
W_p1p2 = np.random.rand(n_p1, n_p2)
W_p2p1 = np.random.rand(n_p2, n_p1)
# -- computation
# Time 2000 sequential update steps; the elapsed wall-clock time is printed
# once, after the loop.
t1 = tm()
for t in range(2000):
    # Drive of p1: its own previous state, the previous state of p2, and x.
    z_p1 = np.matmul(v_p1, W_p1p1) + np.matmul(v_p2, W_p2p1) + np.matmul(x, W_q1p1)
    # Blend the old activation (weight a_p1) with the rectified drive.
    v_p1_new = a_p1 * v_p1 + (1 - a_p1) * gp(z_p1)
    # Drive of p2 uses the *previous* v_p1, not v_p1_new.
    z_p2 = np.matmul(v_p1, W_p1p2)
    v_p2_new = a_p2 * v_p2 + (1 - a_p2) * gp(z_p2)
    # Commit both updates together so each step reads a consistent state.
    v_p1, v_p2 = v_p1_new, v_p2_new
print(tm() - t1)
输出:0.02118515968322754
JAX代码:
from jax import random, nn, numpy as jnp
from time import time as tm
def gp(x):
    """Apply the ReLU non-linearity to ``x`` via ``jax.nn.relu``.

    Args:
        x: Array accepted by ``jax.nn.relu``.

    Returns:
        The result of ``nn.relu(x)``.
    """
    return nn.relu(x)
# -- inputs
n_q1 = 25
n_q2 = 5
key = random.PRNGKey(0)
# JAX PRNG keys are never advanced implicitly: passing the same key to
# several sampling calls yields correlated draws. Derive one independent
# subkey per array sampled below.
(k_x, k_y, k_v_p1, k_v_p2,
 k_q1p1, k_p2q2, k_p1p1, k_p1p2, k_p2p1) = random.split(key, 9)
x = random.normal(k_x, (1, n_q1))
y = random.normal(k_y, (1, n_q2))
# -- activations
n_p1 = 3
n_p2 = 2
v_p1 = random.normal(k_v_p1, (1, n_p1))
v_p2 = random.normal(k_v_p2, (1, n_p2))
a_p1 = 0.5
a_p2 = 0.5
# -- weights
W_q1p1 = random.normal(k_q1p1, (n_q1, n_p1))
W_p2q2 = random.normal(k_p2q2, (n_p2, n_q2))
W_p1p1 = random.normal(k_p1p1, (n_p1, n_p1))
W_p1p2 = random.normal(k_p1p2, (n_p1, n_p2))
W_p2p1 = random.normal(k_p2p1, (n_p2, n_p1))
# -- computation
# Time 2000 sequential (un-jitted) update steps; each jnp call is dispatched
# individually from Python.
t1 = tm()
for t in range(2000):
    # Drive of p1: its own previous state, the previous state of p2, and x.
    z_p1 = jnp.matmul(v_p1, W_p1p1) + jnp.matmul(v_p2, W_p2p1) + jnp.matmul(x, W_q1p1)
    # Blend the old activation (weight a_p1) with the rectified drive.
    v_p1_new = a_p1 * v_p1 + (1 - a_p1) * gp(z_p1)
    # Drive of p2 uses the *previous* v_p1, not v_p1_new.
    z_p2 = jnp.matmul(v_p1, W_p1p2)
    v_p2_new = a_p2 * v_p2 + (1 - a_p2) * gp(z_p2)
    # Commit both updates together so each step reads a consistent state.
    v_p1, v_p2 = v_p1_new, v_p2_new
# JAX dispatches work asynchronously; wait for the final arrays so the
# measured time covers the computation itself, not only its enqueueing.
v_p1.block_until_ready()
v_p2.block_until_ready()
print(tm() - t1)
输出:2.5548229217529297
JAX 文档在其常见问题解答中对此有一个有用的部分:https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
TL;DR:这很复杂。对于 CPU 上的单个矩阵运算,JAX 通常比 NumPy 慢;但 JAX 中经过 JIT 编译的操作序列通常比 NumPy 快,而且一旦移至 GPU/TPU,JAX 通常会比 NumPy 快得多。
我最近开始学习 JAX。我用 NumPy 编写了一个简短的代码片段,并用 JAX 编写了等价的版本。我原以为 JAX 会更快,但当我对代码进行性能测试时,NumPy 代码比 JAX 代码快得多。我想知道这是普遍现象,还是我的实现存在问题。
NumPy代码:
import numpy as np
from time import time as tm
def gp(x):
    """Rectify ``x`` element-wise (ReLU): negative entries become zero.

    Args:
        x: Array exposing ``.shape`` (e.g. a NumPy array).

    Returns:
        The element-wise maximum of ``x`` and a zero array of the same shape.
    """
    # np.zeros builds a float64 array shaped like x, so np.maximum compares
    # each entry of x against 0.
    return np.maximum(np.zeros(x.shape), x)
# -- inputs
n_q1 = 25
n_q2 = 5
x = np.random.rand(1, n_q1)  # todo: changes w/ time
y = np.random.rand(1, n_q2)  # todo: changes w/ time
# -- activations
n_p1 = 3
n_p2 = 2
v_p1 = np.random.rand(1, n_p1)
v_p2 = np.random.rand(1, n_p2)
a_p1 = 0.5
a_p2 = 0.5
# -- weights
W_q1p1 = np.random.rand(n_q1, n_p1)
W_p2q2 = np.random.rand(n_p2, n_q2)
W_p1p1 = np.random.rand(n_p1, n_p1)
W_p1p2 = np.random.rand(n_p1, n_p2)
W_p2p1 = np.random.rand(n_p2, n_p1)
# -- computation
# Time 2000 sequential update steps; the elapsed wall-clock time is printed
# once, after the loop.
t1 = tm()
for t in range(2000):
    # Drive of p1: its own previous state, the previous state of p2, and x.
    z_p1 = np.matmul(v_p1, W_p1p1) + np.matmul(v_p2, W_p2p1) + np.matmul(x, W_q1p1)
    # Blend the old activation (weight a_p1) with the rectified drive.
    v_p1_new = a_p1 * v_p1 + (1 - a_p1) * gp(z_p1)
    # Drive of p2 uses the *previous* v_p1, not v_p1_new.
    z_p2 = np.matmul(v_p1, W_p1p2)
    v_p2_new = a_p2 * v_p2 + (1 - a_p2) * gp(z_p2)
    # Commit both updates together so each step reads a consistent state.
    v_p1, v_p2 = v_p1_new, v_p2_new
print(tm() - t1)
输出:0.02118515968322754
JAX代码:
from jax import random, nn, numpy as jnp
from time import time as tm
def gp(x):
    """Apply the ReLU non-linearity to ``x`` via ``jax.nn.relu``.

    Args:
        x: Array accepted by ``jax.nn.relu``.

    Returns:
        The result of ``nn.relu(x)``.
    """
    return nn.relu(x)
# -- inputs
n_q1 = 25
n_q2 = 5
key = random.PRNGKey(0)
# JAX PRNG keys are never advanced implicitly: passing the same key to
# several sampling calls yields correlated draws. Derive one independent
# subkey per array sampled below.
(k_x, k_y, k_v_p1, k_v_p2,
 k_q1p1, k_p2q2, k_p1p1, k_p1p2, k_p2p1) = random.split(key, 9)
x = random.normal(k_x, (1, n_q1))
y = random.normal(k_y, (1, n_q2))
# -- activations
n_p1 = 3
n_p2 = 2
v_p1 = random.normal(k_v_p1, (1, n_p1))
v_p2 = random.normal(k_v_p2, (1, n_p2))
a_p1 = 0.5
a_p2 = 0.5
# -- weights
W_q1p1 = random.normal(k_q1p1, (n_q1, n_p1))
W_p2q2 = random.normal(k_p2q2, (n_p2, n_q2))
W_p1p1 = random.normal(k_p1p1, (n_p1, n_p1))
W_p1p2 = random.normal(k_p1p2, (n_p1, n_p2))
W_p2p1 = random.normal(k_p2p1, (n_p2, n_p1))
# -- computation
# Time 2000 sequential (un-jitted) update steps; each jnp call is dispatched
# individually from Python.
t1 = tm()
for t in range(2000):
    # Drive of p1: its own previous state, the previous state of p2, and x.
    z_p1 = jnp.matmul(v_p1, W_p1p1) + jnp.matmul(v_p2, W_p2p1) + jnp.matmul(x, W_q1p1)
    # Blend the old activation (weight a_p1) with the rectified drive.
    v_p1_new = a_p1 * v_p1 + (1 - a_p1) * gp(z_p1)
    # Drive of p2 uses the *previous* v_p1, not v_p1_new.
    z_p2 = jnp.matmul(v_p1, W_p1p2)
    v_p2_new = a_p2 * v_p2 + (1 - a_p2) * gp(z_p2)
    # Commit both updates together so each step reads a consistent state.
    v_p1, v_p2 = v_p1_new, v_p2_new
# JAX dispatches work asynchronously; wait for the final arrays so the
# measured time covers the computation itself, not only its enqueueing.
v_p1.block_until_ready()
v_p2.block_until_ready()
print(tm() - t1)
输出:2.5548229217529297
JAX 文档在其常见问题解答中对此有一个有用的部分:https://jax.readthedocs.io/en/latest/faq.html#is-jax-faster-than-numpy
TL;DR:这很复杂。对于 CPU 上的单个矩阵运算,JAX 通常比 NumPy 慢;但 JAX 中经过 JIT 编译的操作序列通常比 NumPy 快,而且一旦移至 GPU/TPU,JAX 通常会比 NumPy 快得多。