一种非常快速的近似np.random.dirichlet大维度的方法
A very quick method to approximate np.random.dirichlet with large dimension
我想尽快用大尺寸评估 np.random.dirichlet。更准确地说,我想要一个函数以至少 10 倍的速度逼近下面的函数。根据经验,我观察到此函数的小维版本输出一个或两个具有 0.1 阶的条目,并且每个其他条目都非常小以至于它们无关紧要。但这一观察并非基于任何严格的评估。近似不需要那么准确,但我想要一些不太粗糙的东西,因为我正在为 MCTS 使用这种噪声。
def g():
np.random.dirichlet([0.03]*4840)
>>> timeit.timeit(g,number=1000)
0.35117408499991143
假设您的 alpha 在组件上固定并用于多次迭代,您可以将相应伽马分布的 ppf 制表。这可能可用 scipy.stats.gamma.ppf
,但我们也可以使用 scipy.special.gammaincinv
。这个功能似乎比较慢,所以这是一笔巨大的前期投资。
这里是总体思路的粗略实现:
import numpy as np
from scipy import special
class symm_dirichlet:
def __init__(self, alpha, resolution=2**16):
self.alpha = alpha
self.resolution = resolution
self.range, delta = np.linspace(0, 1, resolution,
endpoint=False, retstep=True)
self.range += delta / 2
self.table = special.gammaincinv(self.alpha, self.range)
def draw(self, n_sampl, n_comp, interp='nearest'):
if interp != 'nearest':
raise NotImplementedError
gamma = self.table[np.random.randint(0, self.resolution,
(n_sampl, n_comp))]
return gamma / gamma.sum(axis=1, keepdims=True)
import time, timeit
t0 = time.perf_counter()
X = symm_dirichlet(0.03)
t1 = time.perf_counter()
print(f'Upfront cost {t1-t0:.3f} sec')
print('Running cost per 1000 samples of width 4840')
print('tabulated {:3f} sec'.format(timeit.timeit(
'X.draw(1, 4840)', number=1000, globals=globals())))
print('np.random.dirichlet {:3f} sec'.format(timeit.timeit(
'np.random.dirichlet([0.03]*4840)', number=1000, globals=globals())))
示例输出:
Upfront cost 13.067 sec
Running cost per 1000 samples of width 4840
tabulated 0.059365 sec
np.random.dirichlet 0.980067 sec
最好检查一下是否大致正确:
我想尽快用大尺寸评估 np.random.dirichlet。更准确地说,我想要一个函数以至少 10 倍的速度逼近下面的函数。根据经验,我观察到此函数的小维版本输出一个或两个具有 0.1 阶的条目,并且每个其他条目都非常小以至于它们无关紧要。但这一观察并非基于任何严格的评估。近似不需要那么准确,但我想要一些不太粗糙的东西,因为我正在为 MCTS 使用这种噪声。
def g():
np.random.dirichlet([0.03]*4840)
>>> timeit.timeit(g,number=1000)
0.35117408499991143
假设您的 alpha 在组件上固定并用于多次迭代,您可以将相应伽马分布的 ppf 制表。这可能可用 scipy.stats.gamma.ppf
,但我们也可以使用 scipy.special.gammaincinv
。这个功能似乎比较慢,所以这是一笔巨大的前期投资。
这里是总体思路的粗略实现:
import numpy as np
from scipy import special
class symm_dirichlet:
def __init__(self, alpha, resolution=2**16):
self.alpha = alpha
self.resolution = resolution
self.range, delta = np.linspace(0, 1, resolution,
endpoint=False, retstep=True)
self.range += delta / 2
self.table = special.gammaincinv(self.alpha, self.range)
def draw(self, n_sampl, n_comp, interp='nearest'):
if interp != 'nearest':
raise NotImplementedError
gamma = self.table[np.random.randint(0, self.resolution,
(n_sampl, n_comp))]
return gamma / gamma.sum(axis=1, keepdims=True)
import time, timeit
t0 = time.perf_counter()
X = symm_dirichlet(0.03)
t1 = time.perf_counter()
print(f'Upfront cost {t1-t0:.3f} sec')
print('Running cost per 1000 samples of width 4840')
print('tabulated {:3f} sec'.format(timeit.timeit(
'X.draw(1, 4840)', number=1000, globals=globals())))
print('np.random.dirichlet {:3f} sec'.format(timeit.timeit(
'np.random.dirichlet([0.03]*4840)', number=1000, globals=globals())))
示例输出:
Upfront cost 13.067 sec
Running cost per 1000 samples of width 4840
tabulated 0.059365 sec
np.random.dirichlet 0.980067 sec
最好检查一下是否大致正确: