高斯之间的交集
Intersection between Gaussian
我只是想绘制两个高斯曲线并找到交点。我有以下代码。虽然它没有绘制确切的交叉点,但我真的不明白为什么。好像只是稍微偏离了一点,但是如果我们对减去的高斯函数进行对数,我就完成了派生的解决方案,是的,它似乎应该是正确的。谁能帮忙?非常感谢!
import numpy as np
import matplotlib.pyplot as plt
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
# found online
def solve_gasussians(m1, s1, m2, s2):
a = 1.0/(2.0*s1**2) - 1.0/(2.0*s2**2)
b = m2/(s2**2) - m1/(s1**2)
c = m1**2 /(2*s1**2) - m2**2 / (2.0*s2**2) - np.log(s2/s1)
return np.roots([a,b,c])
s1 = np.linspace(0, 10,300)
s2 = np.linspace(0, 14, 300)
solved_val = solve_gasussians(5.0, 0.5, 7.0, 1.0)
print solved_val
solved_val = solved_val[0]
plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s1, plot_normal(s1, 5.0, 0.5),'r', label='s1')
plt.plot(s2, plot_normal(s2, 7.0, 1.0),'b', label='s2')
plt.plot(solved_val, plot_normal(solved_val, 7.0, 1.0), 'mo')
plt.legend()
plt.show()
不知道你的代码错在哪里。但我想我找到了你借用的代码并做了你需要的部分调整。
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
def solve(m1,m2,std1,std2):
a = 1/(2*std1**2) - 1/(2*std2**2)
b = m2/(std2**2) - m1/(std1**2)
c = m1**2 /(2*std1**2) - m2**2 / (2*std2**2) - np.log(std2/std1)
return np.roots([a,b,c])
m1 = 5
std1 = 0.5
m2 = 7
std2 = 1
result = solve(m1,m2,std1,std2)
x = np.linspace(-5,9,10000)
plot1=plt.plot(x,[norm.pdf(_,m1,std1) for _ in x])
plot2=plt.plot(x,[norm.pdf(_,m2,std2) for _ in x])
plot3=plt.plot(result[0],norm.pdf(result[0],m1,std1) ,'o')
plt.show()
我将主动提供两条建议,它们可能会让您的生活更轻松(就像他们为我做的那样):
- 当您调整代码时,尝试进行小的增量更改,并检查代码在每一步是否仍然有效。
- 查找现有的免费图书馆。在这种情况下,scipy 中的 norm 可以很好地替代原始代码中使用的内容。
错误就在这里。这一行:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
应该是这样的:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
您忘记了sqrt
。
如果可用的话,最好使用已有的普通 pdf,例如:
import scipy.stats
def plot_normal(x, mean = 0, sigma = 1):
return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)
也可以精确求解交叉点。 This answer 提供高斯交叉点的根的二次方程。使用最大值求解 x 给出以下表达式。其中,虽然复杂,但不依赖于迭代方法,可以从更简单的表达式自动生成。
def solve_gaussians(m1,s1,m2,s2):
x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
return x1,x2
加起来就是:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
def plot_normal(x, mean = 0, sigma = 1):
return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)
#Use the equation from [this answer](https://stats.stackexchange.com/a/12213/12116) solved for x
def solve_gaussians(m1,s1,m2,s2):
x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
return x1,x2
s = np.linspace(0, 14,300)
x = solve_gaussians(5.0,0.5,7.0,1.0)
plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s, plot_normal(s, 5.0, 0.5),'r', label='s1')
plt.plot(s, plot_normal(s, 7.0, 1.0),'b', label='s2')
plt.plot(x[0],plot_normal(x[0],5.,0.5),'mo')
plt.plot(x[1],plot_normal(x[1],5.,0.5),'mo')
plt.legend()
plt.show()
给予:
你在 plot_normal
函数中有一个小错误 - 你在分母中缺少平方根。正确版本:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
给出了预期的结果:
还有两点说明。
- 请记住,您通常可以有 2 个方程的根(两个交点),您提供的参数就是这种情况。
据我所知,np.roots
给了你近似的结果,但你很容易得到准确的结果,重写 solve_gasussians
函数为:
def solve_gasussians(m1, s1, m2, s2):
# coefficients of quadratic equation ax^2 + bx + c = 0
a = (s1**2.0) - (s2**2.0)
b = 2 * (m1 * s2**2.0 - m2 * s1**2.0)
c = m2**2.0 * s1**2.0 - m1**2.0 * s2**2.0 - 2 * s1**2.0 * s2**2.0 * np.log(s1/s2)
x1 = (-b + np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a)
x2 = (-b - np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a)
return x1, x2
我只是想绘制两个高斯曲线并找到交点。我有以下代码。虽然它没有绘制确切的交叉点,但我真的不明白为什么。好像只是稍微偏离了一点,但是如果我们对减去的高斯函数进行对数,我就完成了派生的解决方案,是的,它似乎应该是正确的。谁能帮忙?非常感谢!
import numpy as np
import matplotlib.pyplot as plt
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
# found online
def solve_gasussians(m1, s1, m2, s2):
a = 1.0/(2.0*s1**2) - 1.0/(2.0*s2**2)
b = m2/(s2**2) - m1/(s1**2)
c = m1**2 /(2*s1**2) - m2**2 / (2.0*s2**2) - np.log(s2/s1)
return np.roots([a,b,c])
s1 = np.linspace(0, 10,300)
s2 = np.linspace(0, 14, 300)
solved_val = solve_gasussians(5.0, 0.5, 7.0, 1.0)
print solved_val
solved_val = solved_val[0]
plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s1, plot_normal(s1, 5.0, 0.5),'r', label='s1')
plt.plot(s2, plot_normal(s2, 7.0, 1.0),'b', label='s2')
plt.plot(solved_val, plot_normal(solved_val, 7.0, 1.0), 'mo')
plt.legend()
plt.show()
不知道你的代码错在哪里。但我想我找到了你借用的代码并做了你需要的部分调整。
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
def solve(m1,m2,std1,std2):
a = 1/(2*std1**2) - 1/(2*std2**2)
b = m2/(std2**2) - m1/(std1**2)
c = m1**2 /(2*std1**2) - m2**2 / (2*std2**2) - np.log(std2/std1)
return np.roots([a,b,c])
m1 = 5
std1 = 0.5
m2 = 7
std2 = 1
result = solve(m1,m2,std1,std2)
x = np.linspace(-5,9,10000)
plot1=plt.plot(x,[norm.pdf(_,m1,std1) for _ in x])
plot2=plt.plot(x,[norm.pdf(_,m2,std2) for _ in x])
plot3=plt.plot(result[0],norm.pdf(result[0],m1,std1) ,'o')
plt.show()
我将主动提供两条建议,它们可能会让您的生活更轻松(就像他们为我做的那样):
- 当您调整代码时,尝试进行小的增量更改,并检查代码在每一步是否仍然有效。
- 查找现有的免费图书馆。在这种情况下,scipy 中的 norm 可以很好地替代原始代码中使用的内容。
错误就在这里。这一行:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
应该是这样的:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
您忘记了sqrt
。
如果可用的话,最好使用已有的普通 pdf,例如:
import scipy.stats
def plot_normal(x, mean = 0, sigma = 1):
return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)
也可以精确求解交叉点。 This answer 提供高斯交叉点的根的二次方程。使用最大值求解 x 给出以下表达式。其中,虽然复杂,但不依赖于迭代方法,可以从更简单的表达式自动生成。
def solve_gaussians(m1,s1,m2,s2):
x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
return x1,x2
加起来就是:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
def plot_normal(x, mean = 0, sigma = 1):
return scipy.stats.norm.pdf(x,loc=mean,scale=sigma)
#Use the equation from [this answer](https://stats.stackexchange.com/a/12213/12116) solved for x
def solve_gaussians(m1,s1,m2,s2):
x1 = (s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)+m1*s2**2-m2*s1**2)/(s2**2-s1**2)
x2 = -(s1*s2*np.sqrt((-2*np.log(s1/s2)*s2**2)+2*s1**2*np.log(s1/s2)+m2**2-2*m1*m2+m1**2)-m1*s2**2+m2*s1**2)/(s2**2-s1**2)
return x1,x2
s = np.linspace(0, 14,300)
x = solve_gaussians(5.0,0.5,7.0,1.0)
plt.figure('Baseline Distributions')
plt.title('Baseline Distributions')
plt.xlabel('Response Rate')
plt.ylabel('Probability')
plt.plot(s, plot_normal(s, 5.0, 0.5),'r', label='s1')
plt.plot(s, plot_normal(s, 7.0, 1.0),'b', label='s2')
plt.plot(x[0],plot_normal(x[0],5.,0.5),'mo')
plt.plot(x[1],plot_normal(x[1],5.,0.5),'mo')
plt.legend()
plt.show()
给予:
你在 plot_normal
函数中有一个小错误 - 你在分母中缺少平方根。正确版本:
def plot_normal(x, mean = 0, sigma = 1):
return 1.0/np.sqrt(2*np.pi*sigma**2) * np.exp(-((x-mean)**2)/(2*sigma**2))
给出了预期的结果:
还有两点说明。
- 请记住,您通常可以有 2 个方程的根(两个交点),您提供的参数就是这种情况。
据我所知,
np.roots
给了你近似的结果,但你很容易得到准确的结果,重写solve_gasussians
函数为:def solve_gasussians(m1, s1, m2, s2): # coefficients of quadratic equation ax^2 + bx + c = 0 a = (s1**2.0) - (s2**2.0) b = 2 * (m1 * s2**2.0 - m2 * s1**2.0) c = m2**2.0 * s1**2.0 - m1**2.0 * s2**2.0 - 2 * s1**2.0 * s2**2.0 * np.log(s1/s2) x1 = (-b + np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a) x2 = (-b - np.sqrt(b**2.0 - 4.0 * a * c)) / (2.0 * a) return x1, x2