jax 的矢量化指南
Vectorization guidelnes for jax
假设我有一个函数(为简单起见,两个系列之间的协方差,尽管问题更笼统):
def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
现在我有一个“数据框”D
(一个二维数组,其列是我的系列)我想矢量化 cov
函数产生协方差矩阵。现在,有一个显而易见的方法:
cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))
但这似乎有点笨拙。有没有“规范”的方式来做到这一点?
如果您想用 vmap
表达与嵌套 for
循环等效的逻辑,那么是的,它需要嵌套 vmap。我认为您所写的内容可能是您可以为这样的操作获得的规范,尽管如果使用装饰器编写可能会稍微清楚一些:
from functools import partial
@partial(jax.vmap, in_axes=(1, None))
@partial(jax.vmap, in_axes=(None, 1))
def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
但是,对于这个特定的函数,请注意,如果您愿意,您可以使用单个点积来表达同样的事情:
result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))
假设我有一个函数(为简单起见,两个系列之间的协方差,尽管问题更笼统):
def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
现在我有一个“数据框”D
(一个二维数组,其列是我的系列)我想矢量化 cov
函数产生协方差矩阵。现在,有一个显而易见的方法:
cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))
但这似乎有点笨拙。有没有“规范”的方式来做到这一点?
如果您想用 vmap
表达与嵌套 for
循环等效的逻辑,那么是的,它需要嵌套 vmap。我认为您所写的内容可能是您可以为这样的操作获得的规范,尽管如果使用装饰器编写可能会稍微清楚一些:
from functools import partial
@partial(jax.vmap, in_axes=(1, None))
@partial(jax.vmap, in_axes=(None, 1))
def cov(x, y):
return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
但是,对于这个特定的函数,请注意,如果您愿意,您可以使用单个点积来表达同样的事情:
result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))