高斯之间的交集

Intersection between Gaussian

我只是想绘制两个高斯曲线并找到交点。我有以下代码。虽然它没有绘制确切的交叉点,但我真的不明白为什么。好像只是稍微偏离了一点,但是如果我们对减去的高斯函数进行对数,我就完成了派生的解决方案,是的,它似乎应该是正确的。谁能帮忙?非常感谢!

import numpy as np 
import matplotlib.pyplot as plt 

def plot_normal(x, mean = 0, sigma = 1):
    return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))

# found online
def solve_gasussians(m1, s1, m2, s2):
  a = 1.0/(2.0*s1**2) - 1.0/(2.0*s2**2)
  b = m2/(s2**2) - m1/(s1**2)
  c = m1**2 /(2*s1**2) - m2**2 / (2.0*s2**2) - np.log(s2/s1)
  return np.roots([a,b,c])

s1 = np.linspace(0, 10,300)
s2 = np.linspace(0, 14, 300)

solved_val = solve_gasussians(5.0, 0.5, 7.0, 1.0)
print solved_val
solved_val = solved_val[0]
plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s1, plot_normal(s1, 5.0, 0.5),'r', label='s1')
plt.plot(s2, plot_normal(s2, 7.0, 1.0),'b', label='s2')
plt.plot(solved_val, plot_normal(solved_val, 7.0, 1.0), 'mo')
plt.legend()
plt.show()

不知道你的代码错在哪里。但我想我找到了你借用的代码并做了你需要的部分调整。

import numpy as np 
import matplotlib.pyplot as plt 
from scipy.stats import norm

def solve(m1,m2,std1,std2):
  a = 1/(2*std1**2) - 1/(2*std2**2)
  b = m2/(std2**2) - m1/(std1**2)
  c = m1**2 /(2*std1**2) - m2**2 / (2*std2**2) - np.log(std2/std1)
  return np.roots([a,b,c])

m1 = 5
std1 = 0.5
m2 = 7
std2 = 1

result = solve(m1,m2,std1,std2)

x = np.linspace(-5,9,10000)
plot1=plt.plot(x,[norm.pdf(_,m1,std1) for _ in x])
plot2=plt.plot(x,[norm.pdf(_,m2,std2) for _ in x])
plot3=plt.plot(result[0],norm.pdf(result[0],m1,std1) ,'o')

plt.show()

我将主动提供两条建议,它们可能会让您的生活更轻松(就像他们为我做的那样):

  • 当您调整代码时,尝试进行小的增量更改,并检查代码在每一步是否仍然有效。
  • 查找现有的免费图书馆。在这种情况下,scipy 中的 norm 可以很好地替代原始代码中使用的内容。

错误就在这里。这一行:

def plot_normal(x, mean = 0, sigma = 1):
  return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))

应该是这样的:

def plot_normal(x, mean = 0, sigma = 1):
  return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))

您忘记了sqrt

如果可用的话,最好使用已有的普通 pdf,例如:

import scipy.stats
def plot_normal(x, mean = 0, sigma = 1):
  return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)

也可以精确求解交叉点。 This answer 提供高斯交叉点的根的二次方程。使用最大值求解 x 给出以下表达式。其中,虽然复杂,但不依赖于迭代方法,可以从更简单的表达式自动生成。

def solve_gaussians(m1,s1,m2,s2):
  x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
  x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
  return x1,x2

加起来就是:

import numpy as np 
import matplotlib.pyplot as plt 
import scipy.stats

def plot_normal(x, mean = 0, sigma = 1):
  return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)

#Use the equation from [this answer](https://stats.stackexchange.com/a/12213/12116) solved for x
def solve_gaussians(m1,s1,m2,s2):
  x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
  x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
  return x1,x2

s = np.linspace(0, 14,300)
x = solve_gaussians(5.0,0.5,7.0,1.0)

plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s, plot_normal(s, 5.0, 0.5),'r', label='s1')
plt.plot(s, plot_normal(s, 7.0, 1.0),'b', label='s2')
plt.plot(x[0],plot_normal(x[0],5.,0.5),'mo')
plt.plot(x[1],plot_normal(x[1],5.,0.5),'mo')
plt.legend()
plt.show()

给予:

你在 plot_normal 函数中有一个小错误 - 你在分母中缺少平方根。正确版本:

def plot_normal(x, mean = 0, sigma = 1):
    return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))

给出了预期的结果:

还有两点说明。

  1. 请记住,您通常可以有 2 个方程的根(两个交点),您提供的参数就是这种情况。
  2. 据我所知,np.roots 给了你近似的结果,但你很容易得到准确的结果,重写 solve_gasussians 函数为:

    def solve_gasussians(m1, s1, m2, s2):
        # coefficients of quadratic equation ax^2 + bx + c = 0
        a = (s1**2.0) - (s2**2.0)
        b = 2 * (m1 * s2**2.0 - m2 * s1**2.0)
        c = m2**2.0 * s1**2.0 - m1**2.0 * s2**2.0 - 2 * s1**2.0 * s2**2.0 * np.log(s1/s2)
        x1 = (-b + np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a)
        x2 = (-b - np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a)
        return x1, x2