使用 scipy.stats.rv_continuous 对分布进行子类化时出现溢出错误
Overflow error subclassing a distribution using scipy.stats.rv_continuous
In the documentation page of rv_continuous 我们可以找到一个 'custom' 高斯分布 class 如下。
from scipy.stats import rv_continuous
import numpy as np
class gaussian_gen(rv_continuous):
"Gaussian distribution"
def _pdf(self, x):
return np.exp(-x**2 / 2.) / np.sqrt(2.0 * np.pi)
gaussian = gaussian_gen(name='gaussian')
反过来,我尝试创建一个 class 以 2 为底的指数分布,以模拟一些核衰变:
class time_dist(rv_continuous):
def _pdf(self, x):
return 2**(-x)
random_var = time_dist(name = 'decay')
这具有调用 random_var.rvs()
的目的,以便根据我定义的 pdf 生成随机分布的值样本。然而,当我 运行 这样做时,我收到一个 OverflowError,我不太明白为什么。最初我认为这与函数未规范化有关。但是,我一直在更改 _pdf 定义,但无济于事。代码有什么问题还是这种方法不适合定义此类函数?
根据维基百科,pdf of an exponential distribution 将是:
lambda * exp(-lambda*x)
对于 x >= 0
0
对于 x < 0
所以,大概函数应该改成如下:
from scipy.stats import rv_continuous
import numpy as np
import matplotlib.pyplot as plt
class time_dist(rv_continuous):
def _pdf(self, x):
return np.log(2) * 2 ** (-x) if x >= 0 else 0
random_var = time_dist(name='decay')
plt.hist(random_var.rvs(size=500))
plt.show()
In the documentation page of rv_continuous 我们可以找到一个 'custom' 高斯分布 class 如下。
from scipy.stats import rv_continuous
import numpy as np
class gaussian_gen(rv_continuous):
"Gaussian distribution"
def _pdf(self, x):
return np.exp(-x**2 / 2.) / np.sqrt(2.0 * np.pi)
gaussian = gaussian_gen(name='gaussian')
反过来,我尝试创建一个 class 以 2 为底的指数分布,以模拟一些核衰变:
class time_dist(rv_continuous):
def _pdf(self, x):
return 2**(-x)
random_var = time_dist(name = 'decay')
这具有调用 random_var.rvs()
的目的,以便根据我定义的 pdf 生成随机分布的值样本。然而,当我 运行 这样做时,我收到一个 OverflowError,我不太明白为什么。最初我认为这与函数未规范化有关。但是,我一直在更改 _pdf 定义,但无济于事。代码有什么问题还是这种方法不适合定义此类函数?
根据维基百科,pdf of an exponential distribution 将是:
lambda * exp(-lambda*x)
对于x >= 0
0
对于x < 0
所以,大概函数应该改成如下:
from scipy.stats import rv_continuous
import numpy as np
import matplotlib.pyplot as plt
class time_dist(rv_continuous):
def _pdf(self, x):
return np.log(2) * 2 ** (-x) if x >= 0 else 0
random_var = time_dist(name='decay')
plt.hist(random_var.rvs(size=500))
plt.show()