多输入变量的 JAX 自定义 VJP 函数不适用于 NumPyro/HMC-NUTS
A JAX custom VJP function for multiple input variable does not work for NumPyro/HMC-NUTS
我正在尝试使用自定义 VJP(矢量-雅可比积)函数作为 numpyro 中 HMC-NUTS 的模型。我能够制作一个适用于 HMC-NUTS 的单变量函数,如下所示:
import jax.numpy as jnp
from jax import custom_vjp
@custom_vjp
def h(x):
return jnp.sin(x)
def h_fwd(x):
return h(x), jnp.cos(x)
def h_bwd(res, u):
cos_x = res
return (cos_x * u,)
h.defvjp(h_fwd, h_bwd)
这里我手动定义了h(x)=sin(x)。然后,我做了一个测试数据为
import numpy as np
np.random.seed(32)
sigin=0.3
N=20
x=np.sort(np.random.rand(N))*4*np.pi
data=hv(x)+np.random.normal(0,sigin,size=N)
test data
在这种情况下,我能够在 NumPyro 中执行 HMC-NUTS
import numpyro
import numpyro.distributions as dist
def model(x,y):
sigma = numpyro.sample('sigma', dist.Exponential(1.))
x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
#mu=jnp.sin(x-x0)
#mu=hv(x-x0)
mu=h(x-x0)
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
from jax import random
from numpyro.infer import MCMC, NUTS
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, x=x, y=data)
mcmc.print_summary()
有效。
sample: 100%|██████████| 3000/3000 [00:15<00:00, 193.84it/s, 3 steps of size 7.67e-01. acc. prob=0.92]
mean std median 5.0% 95.0% n_eff r_hat
sigma 0.35 0.06 0.34 0.26 0.45 1178.07 1.00
x0 0.07 0.11 0.07 -0.11 0.26 1243.73 1.00
Number of divergences: 0
但是,如果我将多变量函数定义为,
@custom_vjp
def h(x,A):
return A*jnp.sin(x)
def h_fwd(x, A):
res = (A*jnp.cos(x), jnp.sin(x))
return h(x,A), res
def h_bwd(res, u):
A_cos_x, sin_x = res
return (A_cos_x * u, sin_x * u)
h.defvjp(h_fwd, h_bwd)
然后执行 HMC-NUTS 作为
def model(x,y):
sigma = numpyro.sample('sigma', dist.Exponential(1.))
x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
A = numpyro.sample('A', dist.Exponential(1.))
mu=h(x-x0,A)
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, x=x, y=data)
mcmc.print_summary()
然后我得到一个错误,因为
TypeError: mul got incompatible shapes for broadcasting: (3,), (22,).
我怀疑我函数中的输出形状有误。但是,我试了好几次都改不了形状,怎么也搞不明白。
def model(x,y):
sigma = numpyro.sample('sigma', dist.Exponential(1.))
x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
A = numpyro.sample('A', dist.Exponential(1.))
hv=vmap(h,(0,None),0)
mu=hv(x-x0,A)
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
vmap 解决了这个问题。
我正在尝试使用自定义 VJP(矢量-雅可比积)函数作为 numpyro 中 HMC-NUTS 的模型。我能够制作一个适用于 HMC-NUTS 的单变量函数,如下所示:
import jax.numpy as jnp
from jax import custom_vjp
@custom_vjp
def h(x):
return jnp.sin(x)
def h_fwd(x):
return h(x), jnp.cos(x)
def h_bwd(res, u):
cos_x = res
return (cos_x * u,)
h.defvjp(h_fwd, h_bwd)
这里我手动定义了h(x)=sin(x)。然后,我做了一个测试数据为
import numpy as np
np.random.seed(32)
sigin=0.3
N=20
x=np.sort(np.random.rand(N))*4*np.pi
data=hv(x)+np.random.normal(0,sigin,size=N)
test data
在这种情况下,我能够在 NumPyro 中执行 HMC-NUTS
import numpyro
import numpyro.distributions as dist
def model(x,y):
sigma = numpyro.sample('sigma', dist.Exponential(1.))
x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
#mu=jnp.sin(x-x0)
#mu=hv(x-x0)
mu=h(x-x0)
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
from jax import random
from numpyro.infer import MCMC, NUTS
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, x=x, y=data)
mcmc.print_summary()
有效。
sample: 100%|██████████| 3000/3000 [00:15<00:00, 193.84it/s, 3 steps of size 7.67e-01. acc. prob=0.92]
mean std median 5.0% 95.0% n_eff r_hat
sigma 0.35 0.06 0.34 0.26 0.45 1178.07 1.00
x0 0.07 0.11 0.07 -0.11 0.26 1243.73 1.00
Number of divergences: 0
但是,如果我将多变量函数定义为,
@custom_vjp
def h(x,A):
return A*jnp.sin(x)
def h_fwd(x, A):
res = (A*jnp.cos(x), jnp.sin(x))
return h(x,A), res
def h_bwd(res, u):
A_cos_x, sin_x = res
return (A_cos_x * u, sin_x * u)
h.defvjp(h_fwd, h_bwd)
然后执行 HMC-NUTS 作为
def model(x,y):
sigma = numpyro.sample('sigma', dist.Exponential(1.))
x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
A = numpyro.sample('A', dist.Exponential(1.))
mu=h(x-x0,A)
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, x=x, y=data)
mcmc.print_summary()
然后我得到一个错误,因为
TypeError: mul got incompatible shapes for broadcasting: (3,), (22,).
我怀疑我函数中的输出形状有误。但是,我试了好几次都改不了形状,怎么也搞不明白。
def model(x,y):
sigma = numpyro.sample('sigma', dist.Exponential(1.))
x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
A = numpyro.sample('A', dist.Exponential(1.))
hv=vmap(h,(0,None),0)
mu=hv(x-x0,A)
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
vmap 解决了这个问题。