基于 TensorFlow Probability 中的另一个随机变量从张量中选择一个法线
Selecting one normal from a tensor based on another random variable in TensorFlow Probability
我正在尝试 select 根据分类分布的输出从一系列正态分布中提取单个样本,但似乎无法想出完全正确的方法。使用以下内容:
tfp.distributions.JointDistributionSequential([
tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
lambda c: tfp.distributions.Normal([0, 1, -10, 30], 1)[..., c]
])
Returns 正是我想要的单个案例,但是如果我想要一次多个样本,这就会中断(因为 c 变成了一个 numpy 数组而不是一个整数。这可能吗?如果是的话,应该怎么做我去做吗?
(我也尝试过使用 OneHotCategorical 和相乘,但那根本不起作用!)
还有一个 distribution:
tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(probs=[0, 0, .5, .5]),
components_distribution=tfd.Normal([0, 1, -10, 30], 1))
如果您不想像 Brian 建议的那样使用 MixtureSameFamily
,您可以这样做:
tfp.distributions.JointDistributionSequential([
tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
lambda c: tfp.distributions.Normal(tf.gather([0., 1, -10, 30], c), 1)
])
注意我需要在 gather 中的位置添加一个 .
以避免 dtype 错误。
到这里,我们最终要做的是
- 从
Categorical
中抽取 n
个样本
- 构造一批
n
Normal
s,其locs是通过索引n
次到locs 的4向量中获得的
- 从
n
批次 Normal
中抽样。
之前的方法不行,因为Distribution
切片不支持这种"fancy indexing" It would be cool if we did! TF doesn't support it in general, for reasons。
我正在尝试 select 根据分类分布的输出从一系列正态分布中提取单个样本,但似乎无法想出完全正确的方法。使用以下内容:
tfp.distributions.JointDistributionSequential([
tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
lambda c: tfp.distributions.Normal([0, 1, -10, 30], 1)[..., c]
])
Returns 正是我想要的单个案例,但是如果我想要一次多个样本,这就会中断(因为 c 变成了一个 numpy 数组而不是一个整数。这可能吗?如果是的话,应该怎么做我去做吗?
(我也尝试过使用 OneHotCategorical 和相乘,但那根本不起作用!)
还有一个 distribution:
tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(probs=[0, 0, .5, .5]),
components_distribution=tfd.Normal([0, 1, -10, 30], 1))
如果您不想像 Brian 建议的那样使用 MixtureSameFamily
,您可以这样做:
tfp.distributions.JointDistributionSequential([
tfp.distributions.Categorical(probs=[0, 0, 1/2, 1/2]),
lambda c: tfp.distributions.Normal(tf.gather([0., 1, -10, 30], c), 1)
])
注意我需要在 gather 中的位置添加一个 .
以避免 dtype 错误。
到这里,我们最终要做的是
- 从
Categorical
中抽取 - 构造一批
n
Normal
s,其locs是通过索引n
次到locs 的4向量中获得的
- 从
n
批次Normal
中抽样。
n
个样本
之前的方法不行,因为Distribution
切片不支持这种"fancy indexing" It would be cool if we did! TF doesn't support it in general, for reasons。