Parametric estimation of a Gaussian Mixture Model

I am trying to train a model that estimates a GMM. However, the means of the GMM have to be recomputed on every iteration from a mean_placement parameter. I am following a previously provided solution, which I copy and paste here as the original code:

import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

import torch
from torch import nn
from torch import optim
import torch.distributions as D

num_layers = 8
# weights, means, and stdevs are all leaf tensors, so SGD can update them directly
weights = torch.ones(8, requires_grad=True)
means = torch.tensor(np.random.randn(8,2), requires_grad=True)
stdevs = torch.tensor(np.abs(np.random.randn(8,2)), requires_grad=True)

parameters = [weights, means, stdevs]
optimizer1 = optim.SGD(parameters, lr=0.001, momentum=0.9)

num_iter = 10001
for i in range(num_iter):
    # rebuild the mixture distribution from the current parameter values
    mix = D.Categorical(weights)
    comp = D.Independent(D.Normal(means, stdevs), 1)
    gmm = D.MixtureSameFamily(mix, comp)

    optimizer1.zero_grad()
    x = torch.randn(5000,2)  # this can be an arbitrary set of x samples
    loss2 = -gmm.log_prob(x).mean()  # negative log-likelihood of the samples
    loss2.backward()
    optimizer1.step()

    print(i, loss2)

What I want to do instead is:

num_layers = 8
weights = torch.ones(8, requires_grad=True)
means_coef = torch.tensor(10., requires_grad=True)
means = torch.tensor(torch.dstack([torch.linspace(1, means_coef.detach().item(), 8)]*2).squeeze(), requires_grad=True)
stdevs = torch.tensor(np.abs(np.random.randn(8,2)), requires_grad=True)
parameters = [means_coef]
optimizer1 = optim.SGD(parameters, lr=0.001, momentum=0.9)

num_iter = 10001
for i in range(num_iter):
    # recompute the means from means_coef on every iteration
    means = torch.tensor(torch.dstack([torch.linspace(1, means_coef.detach().item(), 8)]*2).squeeze(), requires_grad=True)

    mix = D.Categorical(weights)
    comp = D.Independent(D.Normal(means, stdevs), 1)
    gmm = D.MixtureSameFamily(mix, comp)

    optimizer1.zero_grad()
    x = torch.randn(5000,2)  # this can be an arbitrary set of x samples
    loss2 = -gmm.log_prob(x).mean()  # negative log-likelihood of the samples
    loss2.backward()
    optimizer1.step()

    print(i, means_coef)


However, in this case the parameter is never updated and the gradient of means_coef is always None. Is there a way to fix this?
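
A minimal sketch of what I suspect is happening: calling .detach().item() turns means_coef into a plain Python float, and torch.tensor(...) then builds a brand-new leaf tensor, so the autograd graph never reaches means_coef:

import torch

coef = torch.tensor(10., requires_grad=True)
# .detach().item() yields a plain float; torch.tensor() then creates a fresh
# leaf tensor with no connection to coef in the autograd graph
means = torch.tensor(torch.linspace(1, coef.detach().item(), 8), requires_grad=True)
means.sum().backward()
print(coef.grad)  # None: no gradient ever flows back to coef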

Based on your description, I have re-written your model. If you run it, you can see that all of the parameters change as the model is optimized. The key point is that the means are now computed inside the module from nn.Parameters using only differentiable tensor operations, instead of being re-created with torch.tensor(...) from detached values, so autograd can propagate gradients back to base and scale. I have also provided the model graph at the end. If you want a different parameterization, you can simply modify the GMM class as needed.

import numpy as np
import matplotlib.pyplot as plt
import sklearn.datasets as datasets

import torch
from torch import nn
from torch import optim
import torch.distributions as D

class GMM(nn.Module):
    
    def __init__(self, weights, base, scale, n_cell=8, shift=0, dim=2):
        super(GMM, self).__init__()
        # registering these as nn.Parameter keeps them in the autograd graph
        # and exposes them through model.parameters()
        self.weight = nn.Parameter(weights)
        self.base = nn.Parameter(base)
        self.scale = nn.Parameter(scale)
        self.grid = torch.arange(1, n_cell+1)
        self.shift = shift
        self.n_cell = n_cell
        self.dim = dim
    
    def trsf_grid(self):
        # compute the component means as log_base(scale * grid + shift),
        # using only differentiable ops so gradients reach base and scale
        trsf = (
            torch.log(self.scale * self.grid + self.shift) 
            / torch.log(self.base)
            ).reshape(-1, 1)
        return trsf.expand(self.n_cell, self.dim)
    
    def forward(self, x, std):
        # rebuild the mixture from the current parameters on every call
        means = self.trsf_grid()
        mix = D.Categorical(self.weight)
        comp = D.Independent(D.Normal(means, std), 1)
        gmm = D.MixtureSameFamily(mix, comp)
        return -gmm.log_prob(x).mean()

if __name__ == "__main__":
    weight = torch.ones(8)
    base = torch.tensor(3.)
    scale = torch.tensor(1.)
    stds = torch.tensor(np.abs(np.random.randn(8,2)), requires_grad=False)
    model = GMM(weight, base, scale)
    print(list(model.parameters()))
    
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    for i in range(1000):
        optimizer.zero_grad()
        x = torch.randn(5000,2)
        loss = model(x, stds)
        loss.backward()
        optimizer.step()
        
    print(list(model.parameters()))

In my case it returned the following parameters:

[Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True), Parameter containing:
tensor(3., requires_grad=True), Parameter containing:
tensor(1., requires_grad=True)]

[Parameter containing:
tensor([0.7872, 1.1010, 1.3390, 1.3757, 0.5122, 0.2884, 1.2597, 0.7597],
       requires_grad=True), Parameter containing:
tensor(3.3207, requires_grad=True), Parameter containing:
tensor(0.2814, requires_grad=True)]

This indeed shows that the parameters are being updated. You can also see the computation graph below:
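
As a side note, if you also want to inspect the learned component means themselves (rather than just base and scale), they can be read back through the same trsf_grid() transform after training; a minimal sketch using the model from above:

# read off the means implied by the trained base/scale parameters;
# no gradients are needed for inspection
with torch.no_grad():
    learned_means = model.trsf_grid()
print(learned_means)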