Python Numba/jit 条件和递归(堆栈)使用

Python Numba/jit conditional and recursive (stack) use

全部,

我正在使用 numba JIT 来加速我的 Python 代码,但即使没有安装 numba 和 LLVM,代码也应该可以运行。

我的第一个想法是按如下方式进行:

use_numba = True
try:
    from numba import jit, int32
except ImportError, e:
    use_numba = False

def run_it(parameters):
    # do something
    pass

# define wrapper call function with optimizer
@jit
def run_it_with_numba(parameters):
    return run_it(parameters)

# [...]
# main program 
t_start = timeit.default_timer()

# this is the code I don't like 
if use_numba:
    res = run_it_with_numba(parameters)
else:
    res = run_it(parameters)

t_stop = timeit.default_timer()
print "Numba: ", use_numba, " Time: ", t_stop - t_start

这并没有像我预期的那样工作,因为编译似乎只适用于 run_it_with_numba() 函数——它基本上什么都不做——但不适用于从该函数调用的子例程。

当我在包含工作负载的函数上应用 @jit 时,结果只会变得更好。

是否有机会在主程序中避免包装函数和 if 子句?

有没有办法告诉 Numba 优化从我的入口函数调用的子例程?因为 run_it() 还包含一些函数调用,我希望 @jit 能够处理它。

铜, 麦酒

如果没有安装 Numba,您可以提供 jit 的无用版本:

use_numba = True
try:
    from numba import jit, int32
except ImportError, e:
    use_numba = False
    from _shim import jit, int32

@jit
def run_it(parameters):
    # do something
    pass

# [...]
# main program 
t_start = timeit.default_timer()

res = run_it(eval(row[0]), workfeed, instrument)

t_stop = timeit.default_timer()
print "Numba: ", use_numba, " Time: ", t_stop - t_start

其中 _shim.py 只包含:

def jit(*args, **kwargs):
    def wrapper(f):
        return f
    if len(args) > 0 and (args[0] is marker or not callable(args[0])) \
        or len(kwargs) > 0:
        # @jit(int32(int32, int32)), @jit(signature="void(int32)")
        return wrapper
    elif len(args) == 0:
        # @jit()
        return wrapper
    else:
        # @jit
        return args[0]

def marker(*args, **kwargs): return marker

int32 = marker

我想你想用不同的方式来做这件事。不是包装方法,而是可选地为其设置别名。例如,使用虚拟方法允许实际计时:

import numpy as np
import timeit 

use_numba = False
try:
    import numba as nb
except ImportError, e:
    use_numba = False

def _run_it(a, N):
    s = 0.0
    for k in xrange(N):
        s += k / np.sin(a)

    return s

# define wrapper call function with optimizer
if use_numba:
    print 'Using numba'
    run_it = nb.jit()(_run_it)
else:
    print 'Falling back to python'
    run_it = _run_it

if __name__ == '__main__':
    print timeit.repeat('run_it(50.0, 100000)', setup='from __main__ import run_it', repeat=3, number=100)

运行 这与 use_numba 标志为 True:

$ python nbtest.py
Using numba
[0.18746304512023926, 0.15185213088989258, 0.1636970043182373]

False:

$ python nbtest.py
Falling back to python
[9.707707166671753, 9.779848098754883, 9.770231008529663]

或者在 iPython 笔记本中使用漂亮的 %timeit 魔法:

run_it_numba = nb.jit()(_run_it)

%timeit _run_it(50.0, 10000)
100 loops, best of 3: 9.51 ms per loop

%timeit run_it_numba(50.0, 10000)  
10000 loops, best of 3: 144 µs per loop

请注意,在为 numba 方法计时时,为该方法的单次执行计时将考虑 numba jit 方法所花费的时间。所有后续运行都会快得多。