使用 jax.random.normal 对具有特定均值和标准差的单变量高斯进行采样

sampling univariate gausssian with specific mean and standard deviation using jax.random.normal

我正在尝试从具有特定标准差和均值的高斯中采样,我知道 following function 是从均值为零且标准差等于 1 的高斯中采样:

import jax
from jax import random

key = random.PRNGKey(0)
mu = 20
std = 4

x1 = jax.random.normal(key, (1000,))

我可以通过以下方式调整均值:x1 = x1 + mu,但如何调整标准偏差?

这个

x1 = std * x1 + mu

给你想要的

以这种方式创建样本:

x1 = mu + std * jax.random.normal(key, (1000,))

如果这样做,样本的直方图将遵循预期的分布:

import jax
from jax import random
from jax.scipy.stats import norm
import matplotlib.pyplot as plt

key = random.PRNGKey(0)
mu = 20
std = 4

x1 = mu + std * jax.random.normal(key, (1000,))
plt.hist(x1, bins=50, density=True)

x = jnp.linspace(5, 35, 100)
y = norm.pdf(x, loc=mu, scale=std)
plt.plot(x, y)