最小化 scipy.stats.multivariate_normal.logpdf 关于协方差

Minimizing scipy.stats.multivariate_normal.logpdf with respect to covariance

我有一个 python 脚本,我在其中使用 scipy 的 multivariate_normal.log_pdf 计算双变量数据样本的正常对数似然函数的值。我假设样本均值和方差的值,只留下变量之间的样本相关性作为未知数,

from scipy.stats import multivariate_normal
from scipy.optimize import minimize

VAR_X = 0.4
VAR_Y = 0.32
MEAN_X = 1
MEAN_Y = 1.2

def log_likelihood_function(x, data):
    log_likelihood = 0
    sigma = [ [VAR_X, x[0]], [x[0], VAR_Y] ]
    mu = [ MEAN_X, MEAN_Y ]
    for point in data:
        log_likelihood += multivariate_normal.logpdf(x=point, mean=mu, cov=sigma)
    return log_likelihood

if __name__ == "__main__":
    some_data = [ [1.1, 2.0], [1.2, 1.9], [0.8, 0.2], [0.7, 1.3] ]
    
    guess = [ 0 ] 

    # maximize log-likelihood by minimizing the negative 
    likelihood = lambda x: (-1)*log_likelihood_function(x, some_data)
    
    result = minimize(fun = likelihood, x0 = guess, options = {'disp': True}, method="SLSQP")

    print(result)

无论我的猜测是什么,这个脚本都会可靠地抛出 ValueError,

ValueError: the input matrix must be positive semidefinite

现在,根据我的估计,问题似乎是 scipy.optimize.minimize 是猜测创建非正定协方差矩阵的值。所以我需要一种方法来确保最小化算法丢弃问题域之外的值。我想给最小化调用添加一个约束,

## make the determinant always positive
def positive_definite_constraint(x):
    return VAR_X*VAR_Y - x*x

这基本上是协方差矩阵的 Slyvester Criteron 并且可以确保矩阵是正定的(因为我们知道方差总是正的,所以不需要检查该条件)但它看起来像 scipy.optimize.minimize 在确定是否满足约束之前评估 objective 函数(这似乎是一个设计缺陷;在受限域中搜索解决方案而不是搜索所有可能的解决方案和然后确定是否满足约束?不过,我可能对评估顺序有误。)

我不确定如何进行。我意识到我通过参数化协方差矩阵然后最小化该参数化来稍微扩展 scipy.optimize 的目的,我知道有更好的方法来计算正常样本的相关性,但我很感兴趣在这个问题中,因为它泛化到非正态分布。

有什么建议吗?有没有更好的方法来解决这个问题?

你走在正确的轨道上。请注意,您的确定性约束简化为优化变量的简单界限,即 -∞ <= x[0] <= VAR_X*VAR_Y。与更一般的约束相比,变量边界在内部得到更好的处理,所以我推荐这样的东西:

bounds = [(None, VAR_X*VAR_Y)]
res = minimize(fun = likelihood, x0 = guess, bounds=bounds, options = {'disp': True}, method="SLSQP")

这给了我:

     fun: 6.610504611834715
     jac: array([-0.0063166])
 message: 'Optimization terminated successfully'
    nfev: 9
     nit: 4
    njev: 4
  status: 0
 success: True
       x: array([0.12090069])