JAX vmap 行为

JAX vmap behaviour

我试图了解 JAX vmap 的行为,所以我编写了以下代码:

import jax.numpy as jnp
from jax import vmap

def what(a,b,c):
  z = jnp.dot(a,b)
  return z + c

v_what = vmap(what, in_axes=(None,0,None))

a = jnp.array([1,1,3])
b = jnp.array([2,2])
c = 1.0

v_what(a,b,c)

输出为:

DeviceArray([[3., 3., 7.],
             [3., 3., 7.]], dtype=float32)

我知道唯一被改变的输入是 b,但是有人能解释一下为什么会这样吗?我对函数进行矢量化后,点积的表现如何?

您已指定转换后的函数应映射到 b 的第一个轴上,而不映射到 ac 的任何轴上。粗略地说,您已经创建了一个执行此操作的映射函数:

def v_what(a, b, c):
  return jnp.stack([what(a, b_i, c) for b_i in b], axis=0)

对于您的输入,在每一行中,点积看起来像 jnp.dot(a, 2),结果相当于 a * 2