"Spotting" probability density functions of distributions programmatically (Symbolic Toolbox)

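I have a joint probability density f(x,y,z), and I want to find the conditional distribution X|Y=y,Z=z, which amounts to treating x as the data and y and z as parameters (constants).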

For example, if X|Y=y,Z=z has the pdf of N(1-2y,3z^2+2), the function would be:

syms x y z
f(y,z) = 1/sqrt(2*pi*(3*z^2+2)) * exp(-1/(2*(3*z^2+2)) * (x-(1-2*y))^2);

I want to compare it with the following:

syms mu s L a b
Normal(mu,s) = (1/sqrt(2*pi*s^2)) * exp(-1/(2*s^2) * (x-mu)^2);
Exponential(L) = L * exp(-L*x);
Gamma(a,b) = (b^a / gamma(a)) * x^(a-1)*exp(-b*x);
Beta(a,b) = (1/beta(a,b)) * x^(a-1)*(1-x)^(b-1);
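Taking logarithms makes the comparison easier to see. Up to an additive term that does not involve x, each candidate's log-density is a fixed combination of a few functions of x (here s is the standard deviation, matching the Normal definition above):

```latex
\begin{aligned}
\log \mathrm{Normal}(x;\mu,s) &= -\frac{1}{2s^2}\,x^2 + \frac{\mu}{s^2}\,x + C_1 \\
\log \mathrm{Exponential}(x;L) &= -L\,x + C_2 \\
\log \mathrm{Gamma}(x;a,b) &= (a-1)\log x - b\,x + C_3 \\
\log \mathrm{Beta}(x;a,b) &= (a-1)\log x + (b-1)\log(1-x) + C_4
\end{aligned}
```

where C_1, ..., C_4 do not depend on x. Matching f up to proportionality therefore amounts to matching the coefficients of x^2, x, log x and log(1-x) in log f.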

Question

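How can I write a program whichDistribution that prints which of these four distributions f is equivalent to (up to proportionality) in the variable x, and with what parameters? E.g. for f as above, the distribution is Normal with mu=1-2*y and s^2=3*z^2+2.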

Note: there is not always a unique solution, since some distributions are equivalent (e.g. Gamma(1,L)==Exponential(L)).
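As a concrete instance of such an equivalence, setting a = 1 and b = L in the Gamma density above gives the Exponential density exactly:

```latex
\mathrm{Gamma}(x;1,L) = \frac{L^{1}}{\Gamma(1)}\,x^{0}\,e^{-Lx} = L\,e^{-Lx} = \mathrm{Exponential}(x;L)
```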


Desired output

syms x y z
f = 1/sqrt(2*pi*(3*z^2+2)) * exp(-1/(2*(3*z^2+2)) * (x-(1-2*y))^2)
whichDistribution(f,x) %Conditional X|Y,Z
% Normal(1-2*y,3*z^2+2)
syms x y
f = y^(1/2)*exp(-(x^2)/2 - y/2 * (1+(4-x)^2+(6-x)^2)) % this is not a pdf because it is missing a constant of proportionality, but it should still work

whichDistribution(f,x)  %Conditional X|Y
% Normal(10*y/(2*y+1), 1/(2*y+1))

whichDistribution(f,y)  %Conditional Y|X
% Gamma(3/2, x^2 - 10*x + 53/2)
f = exp(-x) %also missing a constant of proportionality
whichDistribution(f,x)
% Exponential(1)
f = 1/(2*pi)*exp(-(x^2)/2 - (y^2)/2)
whichDistribution(f,x)
% Normal(0,1)
whichDistribution(f,y)
% Normal(0,1)
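The expected answers for the second example can be checked by hand. Using (4-x)^2 + (6-x)^2 = 2x^2 - 20x + 52 and collecting terms, first in x and then in y:

```latex
\begin{aligned}
\log f &= \tfrac{1}{2}\log y - \tfrac{1}{2}x^2 - \tfrac{y}{2}\left(2x^2 - 20x + 53\right) \\
       &= -\frac{2y+1}{2}\,x^2 + 10y\,x + (\text{terms without } x)
        = -\frac{\left(x - \frac{10y}{2y+1}\right)^2}{2\cdot\frac{1}{2y+1}} + (\text{terms without } x) \\
       &= \left(\tfrac{3}{2} - 1\right)\log y - \left(x^2 - 10x + \tfrac{53}{2}\right) y + (\text{terms without } y)
\end{aligned}
```

So X|Y is Normal with mean 10y/(2y+1) and variance 1/(2y+1), and Y|X is Gamma with shape 3/2 and rate x^2 - 10x + 53/2.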

What I have tried so far:

  1. Using solve():
q = solve(f(y,z) == Normal(mu,s), mu, s)

This gives a wrong result, because the parameters cannot depend on x:

>> q.mu
ans =
(z1^2*(log((2^(1/2)*exp(x^2/(2*z1^2) - (x + 2*y - 1)^2/(6*z^2 + 4)))/(2*pi^(1/2)*(3*z^2 + 2)^(1/2))) + pi*k*2i))/x
>> q.s
ans =
z1
  2. Trying to simplify f(y,z) up to proportionality (in the variable x) with a propto() function I wrote:
>> propto(f(y,z),x)
ans =
exp(-(x*(x + 4*y - 2))/(2*(3*z^2 + 2)))

>> propto(Normal(mu,s),x)
ans =
exp((x*(2*mu - x))/(2*s^2))

This is almost right, since it is easy to see that s^2=3*z^2 + 2 and 2*mu=-(4*y - 2), but I don't know how to deduce this programmatically.
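One way to do this programmatically, sketched here as an alternative to propto() rather than a fix of it: if log f is exactly quadratic in x, say a*x^2 + b*x + c, then matching it against the Normal log-density gives s^2 = -1/(2a) and mu = -b/(2a). The coefficients a and b can be taken from derivatives, so the terms free of x never need to be identified. This assumes MATLAB's Symbolic Math Toolbox; the variable names are only illustrative.

```matlab
% Sketch (assumes Symbolic Math Toolbox): read mu and s^2 off a log-density
% that is quadratic in x, log f = a*x^2 + b*x + c.
syms x y z
f = 1/sqrt(2*pi*(3*z^2+2)) * exp(-1/(2*(3*z^2+2)) * (x-(1-2*y))^2);

% Turn the product into a sum; factors free of x end up in the constant c
logf = expand(log(f), 'IgnoreAnalyticConstraints', true);

% The Normal family applies only if the third derivative in x vanishes
isQuadratic = isAlways(simplify(diff(logf, x, 3)) == 0);

% Coefficients from derivatives: a = (log f)''/2, b = (log f)' at x = 0
a = simplify(diff(logf, x, 2) / 2);
b = simplify(subs(diff(logf, x), x, 0));

% Match against -x^2/(2*s^2) + (mu/s^2)*x
s2 = simplify(-1 / (2*a));    % variance s^2
mu = simplify(-b / (2*a));    % mean
```

Applied to the second example in the question, the same lines give mu = 10*y/(2*y + 1) and s^2 = 1/(2*y + 1). The third-derivative check matters: if it fails, log f is not quadratic in x and the Normal family should be rejected.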



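In case it is useful: propto(f,x) tries to simplify f by dividing it by those children of f that do not involve x, and then outputs the form with the fewest children. Here is the routine: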

function out = propto(f,x)
oldf = f;
newf = propto2(f,x);
while (~strcmp(char(oldf),char(newf))) % if the form of f changed, do propto2 again. When propto2(f) == f, stop
    oldf = newf;
    newf = propto2(oldf,x);
end
out = newf;
end
function out = propto2(f,x)
t1 = children(expand(f)); % expanded f
i1 = ~has([t1{:}],x);
out1 = simplify(f/prod([t1{i1}])); % divides expanded f by terms that do not involve x

t2 = children(f); % unexpanded f
i2 = ~has([t2{:}],x);
out2 = simplify(f/prod([t2{i2}])); % divides f by terms that do not involve x

A = [f, symlength(f); out1, symlength(out1); out2, symlength(out2)];
A = sortrows(A,2); % outputs whichever form has the fewest number of children
out = A(1,1);
end
function L = symlength(f)
% counts the number of children of f by repeatedly applying children() to itself
t = children(f);
t = [t{:}];
L = length(t);
if (L == 1)
    return
end
oldt = f;
while(~strcmp(char(oldt),char(t)))
    oldt = t;
    t = children(t);
    t = [t{:}];
    t = [t{:}];
end
L = length(t);
end
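A related idea, sketched here as a possible alternative to searching for x-free children: on an interval where both are positive, two functions of x are proportional (with a factor that does not depend on x) exactly when their derivatives d/dx log(.) agree, because differentiating the log removes every constant factor at once. The helper name score is hypothetical, and the snippet assumes the Symbolic Math Toolbox.

```matlab
% Sketch: test proportionality in x by comparing d/dx log(.), which is
% unchanged when a function is multiplied by any factor free of x.
syms x y z mu s

f = 1/sqrt(2*pi*(3*z^2+2)) * exp(-1/(2*(3*z^2+2)) * (x-(1-2*y))^2);
normalPdf = (1/sqrt(2*pi*s^2)) * exp(-1/(2*s^2) * (x-mu)^2);

% Hypothetical helper: the "score" d/dx log g, simplified
score = @(g) simplify(diff(log(g), x));

% Candidate Normal with mu = 1-2*y and s^2 = 3*z^2+2
candidate = subs(normalPdf, [mu, s], [1 - 2*y, sqrt(3*z^2 + 2)]);

% Scaling f by a factor free of x (here 7*y) must not change the result
isProportional = isAlways(score(7*y*f) == score(candidate));
```

The trade-off is that log(g) has to be differentiable in x, so piecewise or non-smooth inputs would need separate handling.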

Edit: added the desired output

Edit 2: clarified the desired functionality

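I have solved my own problem using solve() from the Symbolic Toolbox. My original approach had two problems: I needed to set up n simultaneous equations for n parameters, and solve() does not cope well with exponentials: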

solve(f(3) == g(3), f(4) == g(4), mu,s)

has no solution, but

logf(x) = feval(symengine,'simplify',log(f),'IgnoreAnalyticConstraints');
logg(x) = feval(symengine,'simplify',log(g),'IgnoreAnalyticConstraints');
solve(logf(3) == logg(3), logf(4) == logg(4), mu,s)

gives nice solutions.
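One way to see why the logarithm helps, assuming the Normal kernel has the form exp(x(2mu - x)/(2s)) (what propto() returned earlier, with s^2 replaced by the variance s as in the code below): each log-equation is linear in the combinations mu/s and 1/s, whereas the original equations have the unknowns inside exp.

```latex
\log\left(e^{\frac{r(2\mu - r)}{2s}}\right) = \frac{r(2\mu - r)}{2s} = r\,\frac{\mu}{s} - \frac{r^2}{2}\cdot\frac{1}{s}
```

Two distinct nonzero points r1 and r2 therefore give a 2-by-2 linear system in mu/s and 1/s, whose determinant r1 r2 (r1 - r2)/2 is nonzero.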


Solution

Given f(x), for each PDF g(x) we try to solve simultaneously

log(f(r1)) == log(g(r1)) and log(f(r2)) == log(g(r2))

for some simple, unequal numbers r1 and r2, and then output the g whose solution has the lowest complexity.
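A possible variant of this two-point idea, sketched under the assumption that only an unnormalized kernel of f is available: comparing differences log f(r) - log f(r0) cancels the unknown constant of proportionality, so f and the candidate kernel need not be scaled the same way. The evaluation points 0, 1, 2, the variance parametrization v and the helper d are arbitrary choices made for this sketch; it assumes the Symbolic Math Toolbox.

```matlab
% Sketch (assumes Symbolic Math Toolbox): two-point matching on differences
% of logs, which cancel the unknown constant of proportionality.
syms x y mu v
f = y^(1/2)*exp(-(x^2)/2 - y/2 * (1+(4-x)^2+(6-x)^2));  % unnormalized in x
g = exp(-(x - mu)^2 / (2*v));                           % Normal kernel, v = variance

logf = expand(log(f), 'IgnoreAnalyticConstraints', true);
logg = expand(log(g), 'IgnoreAnalyticConstraints', true);

% Hypothetical helper: difference of a log-kernel between points r and r0
d = @(h, r, r0) subs(h, x, r) - subs(h, x, r0);

% Two equations for two unknowns, using three distinct points 0, 1, 2
eqns = [d(logf, 1, 0) == d(logg, 1, 0), ...
        d(logf, 2, 0) == d(logg, 2, 0)];
sol = solve(eqns, [mu, v]);
```

This needs one more evaluation point than there are parameters, which is one possible use for the third random point r3 that the code below generates but does not use.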


The code is:

function whichDist(f,x)

syms mu s L a b x0 x1 x2 v n p g

f = propto(f,x); % simplify up to proportionality
logf(x) = feval(symengine,'simplify',log(f),'IgnoreAnalyticConstraints');
Normal(mu,s,x) = propto((1/sqrt(2*pi*s)) * exp(-1/(2*s) * (x-mu)^2),x);
Exponential(L,x) = exp(-L*x);
Gamma(a,b,x) = x^(a-1)*exp(-b*x);
Beta(a,b,x) = x^(a-1)*(1-x)^(b-1);
ChiSq(v,x) = x^(v/2 - 1) * exp(-x/2);
tdist(v,x) = (1+x^2 / v)^(-(v+1)/2);
Cauchy(g,x0,x) = 1/(1+((x-x0)/g)^2);

logf = logf(x);
best_sol = {'none', inf};
r1 = randi(10); r2 = randi(10); r3 = randi(10);
while (r1 == r2 || r2 == r3 || r1 == r3) r1 = randi(10); r2 = randi(10); r3 = randi(10); end

%% check Exponential:
if (propto(logf,x) == x) % pdf ~ exp(K*x), can read off Lambda directly
    soln = -logf/x;
    if (~has(soln,x)) % any solution can't depend on x
        fprintf('\nExponential: rate L = %s\n\n', soln);
        return
    end
end

%% check Chi-sq:
if (propto(logf + x/2, log(x)) == log(x)) % can read off v directly
    soln = 2*(1+(logf + x/2) / log(x));
    if (~has(soln,x))
        dof = feval(symengine,'simplify',soln,'IgnoreAnalyticConstraints');
        fprintf('\nChi-Squared: v = %s\n\n', dof);
        return
    end
end

%% check t-dist:
h1 = propto(logf,x);
h = simplify(exp(h1) - 1);
if (propto(h,x^2) == x^2) % t-dist: log f ~ log(1 + x^2/v), so h = exp(h1) - 1 = x^2/v; read off v directly
    soln = simplify(x^2 / h);
    if (~has(soln,x))
        fprintf('\nt-dist: v = %s\n\n', soln);
        return
    end
end
h = simplify(exp(-h1) - 1); % try again if propto flipped a sign
if (propto(h,x^2) == x^2) % t-dist: log f ~ log(1 + x^2/v), so h = exp(-h1) - 1 = x^2/v; read off v directly
    soln = simplify(x^2 / h);
    if (~has(soln,x))
        fprintf('\nt-dist: v = %s\n\n', soln);
        return
    end
end

%% check Normal:
logn(x) = feval(symengine,'simplify',log(Normal(mu,s,x)),'IgnoreAnalyticConstraints');
% A = (x - propto(logf/x, x))/2;
% B = simplify(-x/(logf/x - mu/s)/2);
% if (~has(A,x) && ~has(B,x))
%     fprintf('Normal: mu = %s, s^2 = %s', A, B);
%     return
% end
logf(x) = logf;
try % attempt to solve the equation
    % solve simultaneously for two random non-equal integer values r1,r2
    qn = solve(logf(r1) == logn(r1), logf(r2) == logn(r2), mu, s);
catch error
end
if (exist('qn','var')) % if solve() managed to run
    if (~isempty(qn.mu) && ~isempty(qn.s) && ~any(has([qn.mu,qn.s],x))) % if solution exists
        complexity = symlength(qn.mu) + symlength(qn.s);
        if complexity < best_sol{2} % store best solution so far
            best_sol{1} = sprintf('Normal: mu = %s, s^2 = %s', qn.mu, qn.s);
            best_sol{2} = complexity;
        end
    end
end


%% check Cauchy:
logcau(x) = feval(symengine,'simplify',log(Cauchy(g,x0,x)),'IgnoreAnalyticConstraints');
f(x) = f;
try
    qcau = solve(f(r1) == Cauchy(g,x0,r1), f(r2) == Cauchy(g,x0,r2), g, x0);
catch error
end
if (exist('qcau','var'))
    if (~isempty(qcau.g) && ~isempty(qcau.x0) && ~any(has([qcau.g(1),qcau.x0(1)],x)))
        complexity = symlength(qcau.g(1)) + symlength(qcau.x0(1));
        if complexity < best_sol{2}
            best_sol{1} = sprintf('Cauchy: g = %s, x0 = %s', qcau.g(1), qcau.x0(1));
            best_sol{2} = complexity;
        end
    end
end
f = f(x);

%% check Gamma:
logg(x) = feval(symengine,'simplify',log(Gamma(a,b,x)),'IgnoreAnalyticConstraints');
t = children(logf); t = [t{:}];
if (length(t) == 2)
    if (propto(t(1),log(x)) == log(x) && propto(t(2),x) == x)
        soln = [t(1)/log(x) + 1, -t(2)/x];
        if (~any(has(soln,x)))
            fprintf('\nGamma: shape a = %s, rate b = %s\n\n',soln);
            return
        end
    elseif (propto(t(2),log(x)) == log(x) && propto(t(1),x) == x)
        soln = [t(2)/log(x) + 1, -t(1)/x];
        if (~any(has(soln,x)))
            fprintf('\nGamma: shape a = %s, rate b = %s\n\n',soln);
            return
        end
    end
end
logf(x) = logf;
try % also try using solve(), just in case.
    qg = solve(logf(r1) == logg(r1), logf(r2) == logg(r2), a, b);
catch error
end
if (exist('qg','var'))
    if (~isempty(qg.a) && ~isempty(qg.b) && ~any(has([qg.a,qg.b],x)))
        complexity = symlength(qg.a) + symlength(qg.b);
        if complexity < best_sol{2}
            best_sol{1} = sprintf('Gamma: shape a = %s, rate b = %s', qg.a, qg.b);
            best_sol{2} = complexity;
        end
    end
end
logf = logf(x);

%% check Beta:
B = feval(symengine,'simplify',log(propto(f,x-1)),'IgnoreAnalyticConstraints');
if (propto(B,log(x-1)) == log(x-1))
    B = B / log(x-1) + 1;
    A = f / (x-1)^(B-1);
    A = feval(symengine,'simplify',log(abs(A)),'IgnoreAnalyticConstraints');
    if (propto(A,log(abs(x))) == log(abs(x)))
        A = A / log(abs(x)) + 1;
        if (~any(has([A,B],x)))
            fprintf('\nBeta1: a = %s, b = %s\n\n', A, B);
            return
        end
    end
elseif (propto(B,log(1-x)) == log(1-x))
    B = B / log(1-x);
    A = simplify(f / (1-x)^(B-1));
    A = feval(symengine,'simplify',log(A),'IgnoreAnalyticConstraints');
    if (propto(A,log(x)) == log(x))
        A = A / log(x) + 1;
        if (~any(has([A,B],x)))
            fprintf('\nBeta1: a = %s, b = %s\n\n', A, B);
            return
        end
    end
end

%% Print solution with lowest complexity
fprintf('\n%s\n\n', best_sol{1});

end
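The direct Gamma check above relies on children() returning exactly two terms in a particular order. A possibly more robust alternative, sketched here as an assumption rather than a tested replacement: for a Gamma kernel log f = (a-1)*log(x) - b*x + const, so x * d/dx log f = (a-1) - b*x is linear in x and both parameters can be read off it. The example reuses the question's f and treats y as the variable; positivity of a and b is not checked, and the Exponential case (a = 1) also passes, in line with the non-uniqueness note in the question.

```matlab
% Sketch (assumes Symbolic Math Toolbox): Gamma kernel in the variable y,
% log f = (a-1)*log(y) - b*y + const  =>  y * d/dy log f = (a-1) - b*y.
syms x y
f = y^(1/2)*exp(-(x^2)/2 - y/2 * (1+(4-x)^2+(6-x)^2));

logf = expand(log(f), 'IgnoreAnalyticConstraints', true);
t = expand(y * diff(logf, y));               % should be linear in y

isGamma = isAlways(simplify(diff(t, y, 2)) == 0);   % linearity check
b = simplify(-diff(t, y));                   % rate
a = simplify(subs(t, y, 0) + 1);             % shape
```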

Tests:

>> syms x y z
>> f = y^(1/2)*exp(-(x^2)/2 - y/2 * (1+(4-x)^2+(6-x)^2))
>> whichDist(f,x)
Normal: mu = (10*y)/(2*y + 1), s^2 = 1/(2*y + 1)
>> whichDist(f,y)
Gamma: a = 3/2, b = x^2 - 10*x + 53/2
>> Beta(a,b,x) = propto((1/beta(a,b)) * x^(a-1)*(1-x)^(b-1), x);
>> f = Beta(1/z + 7*y/(1-sqrt(z)), z/y + 1/(1-z), x)
>> whichDist(f,x)
Beta: a = -(7*y*z - z^(1/2) + 1)/(z*(z^(1/2) - 1)), b = -(y + z - z^2)/(y*(z - 1))

All correct.

If the parameters are numbers, it sometimes gives spurious answers:

whichDist(Beta(3,4,x),x)
Beta: a = -(pi*log(2)*1i + pi*log(3/10)*1i - log(2)*log(3/10) + log(2)*log(7/10) - log(3/10)*log(32) + log(2)*log(1323/100000))/(log(2)*(log(3/10) - log(7/10))), b = (pi*log(2)*1i + pi*log(7/10)*1i + log(2)*log(3/10) - log(2)*log(7/10) - log(7/10)*log(32) + log(2)*log(1323/100000))/(log(2)*(log(3/10) - log(7/10)))

So there is still room for improvement, and I will still reward a solution that does better than this one.
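A sketch of one way to avoid the complex logarithms for numeric parameters, assuming the same derivative idea carries over: for a Beta kernel, x*(1-x) * d/dx log f = (a-1)*(1-x) - (b-1)*x, which is linear in x, so a - 1 is its value at x = 0 and -(b - 1) its value at x = 1. The kernel x^2*(1-x)^3 stands in for the failing Beta(3,4) example above; limit() is used instead of subs() in case simplify() leaves a removable singularity at x = 0 or x = 1.

```matlab
% Sketch (assumes Symbolic Math Toolbox): Beta kernel,
% x*(1-x) * d/dx log f = (a-1)*(1-x) - (b-1)*x.
syms x
f = x^2 * (1 - x)^3;      % Beta(3,4) kernel

logf = expand(log(f), 'IgnoreAnalyticConstraints', true);   % 2*log(x) + 3*log(1 - x)
t = simplify(x*(1 - x)*diff(logf, x));   % should be linear in x

isBeta = isAlways(simplify(diff(t, x, 2)) == 0);   % linearity check
a = 1 + limit(t, x, 0);   % t(0) = a - 1
b = 1 - limit(t, x, 1);   % t(1) = -(b - 1)
```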


Edit: added more distributions. Improved the detection of the Gamma and Beta distributions so that they are spotted directly, without solve().