jax 的矢量化指南

Vectorization guidelnes for jax

假设我有一个函数(为简单起见,两个系列之间的协方差,尽管问题更笼统):

def cov(x, y):
   return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))

现在我有一个“数据框”D(一个二维数组,其列是我的系列)我想矢量化 cov函数产生协方差矩阵。现在,有一个显而易见的方法:

cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))

但这似乎有点笨拙。有没有“规范”的方式来做到这一点?

如果您想用 vmap 表达与嵌套 for 循环等效的逻辑,那么是的,它需要嵌套 vmap。我认为您所写的内容可能是您可以为这样的操作获得的规范,尽管如果使用装饰器编写可能会稍微清楚一些:

from functools import partial

@partial(jax.vmap, in_axes=(1, None))
@partial(jax.vmap, in_axes=(None, 1))
def cov(x, y):
   return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))

但是,对于这个特定的函数,请注意,如果您愿意,您可以使用单个点积来表达同样的事情:

result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))