JAX vmap 行为
JAX vmap behaviour
我试图了解 JAX vmap 的行为,所以我编写了以下代码:
import jax.numpy as jnp
from jax import vmap
def what(a,b,c):
z = jnp.dot(a,b)
return z + c
v_what = vmap(what, in_axes=(None,0,None))
a = jnp.array([1,1,3])
b = jnp.array([2,2])
c = 1.0
v_what(a,b,c)
输出为:
DeviceArray([[3., 3., 7.],
[3., 3., 7.]], dtype=float32)
我知道唯一被改变的输入是 b
,但是有人能解释一下为什么会这样吗?我对函数进行矢量化后,点积的表现如何?
您已指定转换后的函数应映射到 b
的第一个轴上,而不映射到 a
或 c
的任何轴上。粗略地说,您已经创建了一个执行此操作的映射函数:
def v_what(a, b, c):
return jnp.stack([what(a, b_i, c) for b_i in b], axis=0)
对于您的输入,在每一行中,点积看起来像 jnp.dot(a, 2)
,结果相当于 a * 2
。
我试图了解 JAX vmap 的行为,所以我编写了以下代码:
import jax.numpy as jnp
from jax import vmap
def what(a,b,c):
z = jnp.dot(a,b)
return z + c
v_what = vmap(what, in_axes=(None,0,None))
a = jnp.array([1,1,3])
b = jnp.array([2,2])
c = 1.0
v_what(a,b,c)
输出为:
DeviceArray([[3., 3., 7.],
[3., 3., 7.]], dtype=float32)
我知道唯一被改变的输入是 b
,但是有人能解释一下为什么会这样吗?我对函数进行矢量化后,点积的表现如何?
您已指定转换后的函数应映射到 b
的第一个轴上,而不映射到 a
或 c
的任何轴上。粗略地说,您已经创建了一个执行此操作的映射函数:
def v_what(a, b, c):
return jnp.stack([what(a, b_i, c) for b_i in b], axis=0)
对于您的输入,在每一行中,点积看起来像 jnp.dot(a, 2)
,结果相当于 a * 2
。