带有函数输入的 Numba jitclass

Numba jitclass with function input

我正致力于在 Numba 中开发自适应拒绝采样器。我想使用 class 来实现它,因为我认为它会使代码更清晰,而且我看到 Numba 支持 classes。如果我的 class 可以将函数作为输入,即我想从中采样的分布的日志 pdf,我的 general/useful 会多得多。有什么办法吗?我想另一种方法是在 class 定义本身中定义对数 pdf 方程。

我为什么要这样做?采样器将用作 Gibbs 采样方案的一部分,因此每个采样步骤的加速至关重要。我必须从我只知道一个归一化常数的分布进行模拟,自适应拒绝采样是一种通用技术,可以帮助我在不需要知道这个归一化常数的情况下进行采样。有一个自适应拒绝采样器的 python 实现围绕堆栈溢出浮动,但它对我的目的来说太慢了。它还会由于某种原因在它应该处理的一些模拟数据上随机中断。我在项目的其他部分使用 numba 很幸运,包括在 Gibbs 采样器的一部分上实现了超过 100 倍的加速。

Numba 函数不能将函数作为输入参数。官方文档建议在某些情况下可能在函数工厂中使用闭包作为解决方法:

http://numba.pydata.org/numba-doc/latest/user/faq.html#can-i-pass-a-function-as-an-argument-to-a-jitted-function

复制上面的代码示例 link 以防 url 变得无效:

def make_f(g):
    # Note: a new f() is compiled each time make_f() is called!
    @jit(nopython=True)
    def f(x):
        return g(x) + g(-x)
    return f

f = make_f(my_g_function)
result = f(1)

不确定这是否适用于您的特定情况。我认为将您想要的函数定义为 class 方法会是一个更好的策略,尽管没有代码示例,我只是在猜测。