矢量化嵌套 vmap
Vectorise nested vmap
这是我的一些数据:
import jax.numpy as jnp
import numpyro.distributions as dist
import jax
xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 5)
我想 运行 函数
def func(x, y):
return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
来自 xaxis
和 yaxis
的每对值。
这是一种“缓慢”的方法:
results = np.zeros((len(xaxis), len(yaxis)))
for i in range(len(xaxis)):
for j in range(len(yaxis)):
results[i, j] = func(xaxis[i], yaxis[j])
有效,但速度很慢。
所以这是一个矢量化的方法:
jax.vmap(lambda axis: jax.vmap(func, (None, 0))(axis, yaxis))(xaxis)
快多了,但很难阅读。
是否有编写矢量化版本的简洁方法?我可以用一个 vmap
来完成,而不必将一个嵌套在另一个中吗?
编辑
另一种方式是
jax.vmap(func)(xmesh.flatten(), ymesh.flatten()).reshape(len(xaxis), len(yaxis)).T
但还是很乱。
我相信 与您的问题非常相似;要使用 vmap 复制嵌套 for 循环的逻辑,需要嵌套 vmap。
使用 jax.vmap
的最干净的方法可能是这样的:
from functools import partial
@partial(jax.vmap, in_axes=(0, None))
@partial(jax.vmap, in_axes=(None, 0))
def func(x, y):
return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
func(xaxis, yaxis)
这里的另一种选择是使用 jnp.vectorize
API(通过多个 vmap 实现),在这种情况下,您可以这样做:
print(jnp.vectorize(func)(xaxis[:, None], yaxis))
这是我的一些数据:
import jax.numpy as jnp
import numpyro.distributions as dist
import jax
xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 5)
我想 运行 函数
def func(x, y):
return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
来自 xaxis
和 yaxis
的每对值。
这是一种“缓慢”的方法:
results = np.zeros((len(xaxis), len(yaxis)))
for i in range(len(xaxis)):
for j in range(len(yaxis)):
results[i, j] = func(xaxis[i], yaxis[j])
有效,但速度很慢。
所以这是一个矢量化的方法:
jax.vmap(lambda axis: jax.vmap(func, (None, 0))(axis, yaxis))(xaxis)
快多了,但很难阅读。
是否有编写矢量化版本的简洁方法?我可以用一个 vmap
来完成,而不必将一个嵌套在另一个中吗?
编辑
另一种方式是
jax.vmap(func)(xmesh.flatten(), ymesh.flatten()).reshape(len(xaxis), len(yaxis)).T
但还是很乱。
我相信
使用 jax.vmap
的最干净的方法可能是这样的:
from functools import partial
@partial(jax.vmap, in_axes=(0, None))
@partial(jax.vmap, in_axes=(None, 0))
def func(x, y):
return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
func(xaxis, yaxis)
这里的另一种选择是使用 jnp.vectorize
API(通过多个 vmap 实现),在这种情况下,您可以这样做:
print(jnp.vectorize(func)(xaxis[:, None], yaxis))