混合多元高斯分布张量流概率

Mixture of multivariate gaussian distribution tensorflow probability

如标题所述,我正在尝试使用 tensorflow 概率包创建多元正态分布的混合。

在我最初的项目中,我从神经网络的输出中获取分类权重、位置和方差。但是,在创建图表时,出现以下错误:

components[0] batch shape must be compatible with cat shape and other component batch shapes

我使用占位符重现了同样的问题:

import tensorflow as tf
import tensorflow_probability as tfp # dist= tfp.distributions 

tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()

l1 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_2')

log_std = tf.compat.v1.get_variable('log_std', [1, 2], dtype=tf.float32,
                                          initializer=tf.constant_initializer(1.0),
                                          trainable=True)

mix = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,1], name='weights')

cat = tfp.distributions.Categorical(probs=[mix, 1.-mix])
components = [
    tfp.distributions.MultivariateNormalDiag(loc=l1, scale_diag=tf.exp(log_std)),
    tfp.distributions.MultivariateNormalDiag(loc=l2, scale_diag=tf.exp(log_std)),
]

bimix_gauss = tfp.distributions.Mixture(
  cat=cat,
  components=components)

所以,我的问题是,我做错了什么?我调查了错误,似乎 tensorshape_util.is_compatible_with 是引发错误的原因,但我不明白为什么。

谢谢!

您似乎向 tfp.distributions.Categorical 提供了错误的输入。它的 probs 参数应该是 [batch_size, cat_size] 的形状,而你提供的参数是 [cat_size, batch_size, 1]。所以也许尝试用 tf.concat([mix, 1-mix], 1) 参数化 probs

您的log_std 也可能有问题,它与 l1l2 的形状不同。如果 MultivariateNormalDiag 没有正确广播它,请尝试将其形状指定为 (None, 2) 或平铺它,以便它的第一个维度对应于您的位置参数。

当组件类型相同时,MixtureSameFamily 应该更高效。

你只传递了一个分类实例(带.batch_shape [b1,b2,...,bn])和一个MVNDiag实例(带.batch_shape [b1,b2 ,...,bn,numcats]).

只有两个类,不知道伯努利是否可行?