在 python 中生成截断的负二项分布
Generating truncated negative binomial distribution in python
我正在尝试生成遵循由数字组成的截断负二项分布的数据集,使得数字集具有最大值。
def truncated_Nbinom(n, p, max_value, size):
import scipy.stats as sct
temp_size = size
while True:
temp_size *= 2
temp = sct.nbinom.rvs(n, p, size=temp_size)
truncated = temp[temp <= max_value]
if len(truncated) >= size:
return truncated[:size]
当 max_value 和 n 较小时,我能够得到结果。但是,当我尝试使用时:
input_1= truncated_Nbinom(99, 0.3, 99, 5000).tolist()
内核不断死亡。我尝试更改 python 的端口并提高递归限制,但它们没有用。你有什么想法可以让我的代码更快吗?
这是一种方法。您可以计算 x
在负二项式下被选中的概率,然后将 x
低于 max_value
的概率归一化为一。现在,您可以简单地以适当的概率调用 np.random.choice
。
import numpy as np
import pandas as pd
from scipy import stats
def truncated_Nbinom2(n, p, max_value, size):
support = np.arange(max_value + 1)
probs = stats.nbinom.pmf(support, n, p)
probs /= probs.sum()
return np.random.choice(support, size=size, p=probs)
这是一个例子:
arr1 = truncated_Nbinom(9, 0.3, 9, 50000)
arr2 = truncated_Nbinom2(9, 0.3, 9, 50000)
df_counts = pd.DataFrame({
"version_1": pd.Series(arr1).value_counts(),
"version_2": pd.Series(arr2).value_counts(),
})
我正在尝试生成遵循由数字组成的截断负二项分布的数据集,使得数字集具有最大值。
def truncated_Nbinom(n, p, max_value, size):
import scipy.stats as sct
temp_size = size
while True:
temp_size *= 2
temp = sct.nbinom.rvs(n, p, size=temp_size)
truncated = temp[temp <= max_value]
if len(truncated) >= size:
return truncated[:size]
当 max_value 和 n 较小时,我能够得到结果。但是,当我尝试使用时:
input_1= truncated_Nbinom(99, 0.3, 99, 5000).tolist()
内核不断死亡。我尝试更改 python 的端口并提高递归限制,但它们没有用。你有什么想法可以让我的代码更快吗?
这是一种方法。您可以计算 x
在负二项式下被选中的概率,然后将 x
低于 max_value
的概率归一化为一。现在,您可以简单地以适当的概率调用 np.random.choice
。
import numpy as np
import pandas as pd
from scipy import stats
def truncated_Nbinom2(n, p, max_value, size):
support = np.arange(max_value + 1)
probs = stats.nbinom.pmf(support, n, p)
probs /= probs.sum()
return np.random.choice(support, size=size, p=probs)
这是一个例子:
arr1 = truncated_Nbinom(9, 0.3, 9, 50000)
arr2 = truncated_Nbinom2(9, 0.3, 9, 50000)
df_counts = pd.DataFrame({
"version_1": pd.Series(arr1).value_counts(),
"version_2": pd.Series(arr2).value_counts(),
})