如何使用 Tensorflow 从范围内的正态分布随机值中采样?
How to sample from a normal distribution random values inside a range using Tensorflow?
我有两个变量 mean
和 stddev
,它们是形状为 (1,) 的张量,它们表示许多具有均值的正态分布,比如 mean[i] 和标准偏差 stddev[i] .
我想从这些分布中为每个人在 [low
,up
,] 范围内采样一个值,然后我想获得采样值的对数概率。
从 docs 我发现 experimental_sample_and_log_prob
方法几乎适合我,因为它不会对我想要的值范围(低,高)内的元素进行采样。
所以我编写了几行代码,但它自然不能很好地工作,而且计算量太大。
import tensorflow as tf
from tensorflow_probability import distributions as tfd
def sample_and_log_prob(dist, up, down):
samples = dist.sample()
accepted = False
print("Is {} accepted? {}".format(samples, accepted))
while not accepted:
# sample < up
cond1 = tf.less_equal(samples, up)
# sample > down
cond2 = tf.greater_equal(samples, down)
# if down < sample < up
accepted = tf.logical_and(cond1, cond2)
samples = tf.where(
tf.logical_not(accepted),
samples,
dist.sample())
print("Is {} accepted? {}".format(samples, accepted))
return samples, dist.log_prob(samples)
distribution = tfd.Normal(
loc=mean ,
scale=stddev,
validate_args=True,
allow_nan_stats=False)
samples, log_probs = sample_and_log_prob(distribution, up=-1, down=1)
有什么解决办法吗?
听起来你想要一个 TruncatedNormal
分布。
我有两个变量 mean
和 stddev
,它们是形状为 (1,) 的张量,它们表示许多具有均值的正态分布,比如 mean[i] 和标准偏差 stddev[i] .
我想从这些分布中为每个人在 [low
,up
,] 范围内采样一个值,然后我想获得采样值的对数概率。
从 docs 我发现 experimental_sample_and_log_prob
方法几乎适合我,因为它不会对我想要的值范围(低,高)内的元素进行采样。
所以我编写了几行代码,但它自然不能很好地工作,而且计算量太大。
import tensorflow as tf
from tensorflow_probability import distributions as tfd
def sample_and_log_prob(dist, up, down):
samples = dist.sample()
accepted = False
print("Is {} accepted? {}".format(samples, accepted))
while not accepted:
# sample < up
cond1 = tf.less_equal(samples, up)
# sample > down
cond2 = tf.greater_equal(samples, down)
# if down < sample < up
accepted = tf.logical_and(cond1, cond2)
samples = tf.where(
tf.logical_not(accepted),
samples,
dist.sample())
print("Is {} accepted? {}".format(samples, accepted))
return samples, dist.log_prob(samples)
distribution = tfd.Normal(
loc=mean ,
scale=stddev,
validate_args=True,
allow_nan_stats=False)
samples, log_probs = sample_and_log_prob(distribution, up=-1, down=1)
有什么解决办法吗?
听起来你想要一个 TruncatedNormal
分布。