可选地使用 jit 将参数传递给另一个函数

Optionally passing parameters onto another function with jit

我正在尝试对 python 函数进行 jit 编译,并使用可选参数来更改另一个函数调用的参数。

我认为 jit 可能出错的地方是可选参数的默认值是 None,而 jit 不知道如何处理它,或者至少不知道如何处理当它变成一个 numpy 数组时。请参阅下面的粗略概述:

@jit(nopython=True)
def foo(otherFunc,arg1, optionalArg=None):

    if optionalArg is not None:
        out=otherFunc(arg1,optionalArg)

    else:
        out=otherFunc(arg1)
    return out

其中 optionalArg 是 None 或 numpy 数组

一个解决方案是将其转换为如下所示的三个函数,但这感觉有点笨拙,我不喜欢它,尤其是因为速度对于这项任务非常重要。

def foo(otherFunc,arg1,optionalArg=None):

    if optionalArg is not None:
        out=func1(otherFunc,arg1,optionalArg)
    else:
        out=func2(otherFunc,arg1)
    return out

@jit(nopython=True)
def func1(otherFunc,arg1,optionalArg):
    out=otherFunc(arg1,optionalArg)
    return out

@jit(nopython=True)
def func2(otherFunc,arg1):
    out=otherFunc(arg1)
    return out

请注意,除了调用 otherFunc 之外还有其他事情正在发生,这使得使用 jit 是值得的,但我几乎可以肯定这不是问题所在,因为这在没有 optionalArg 部分的情况下工作,所以我决定不包括它。

对于那些好奇的人来说,它的 运行ge-kutta order 4 实现带有可选的额外参数以传递给微分方程。如果你想看到整个事情就问。

回溯相当长,但这里是其中的一部分:

inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
Traceback (most recent call last):

  File "<ipython-input-38-478197aa6a1a>", line 1, in <module>
    inte.rk4(de2,y0,0.001,200,vals=np.ones(4))

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
    raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E168C358>:

This continues...

inte.rk4 是 foo 的等价物,de2 是 otherFunc,y0、0.001 和 200 只是值,我在上面的问题描述中换成了 arg1,而 vals 是 optionalArg。

当我在省略 vals 参数的情况下尝试 运行 时会发生类似的事情:

ysExp=inte.rk4(deExp,y0,0.001,200)
Traceback (most recent call last):

  File "<ipython-input-39-7dde4bcbdc2f>", line 1, in <module>
    ysExp=inte.rk4(deExp,y0,0.001,200)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
    error_rewrite(e, 'typing')

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
    reraise(type(e), e, None)

  File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
    raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E048EA90>:

This continues...

如果您看到文档 here,您可以在 Numba 中明确指定 optional 类型参数。例如(这与文档中的示例相同):

>>> @jit((optional(intp),))
... def f(x):
...     return x is not None
...
>>> f(0)
True
>>> f(None)
False

此外,根据正在进行的对话 this Github issue,您可以使用以下解决方法来实施可选关键字。我修改了 github 问题中提供的解决方案中的代码以适合您的示例:

from numba import jitclass, int32, njit
from collections import OrderedDict
import numpy as np

np_arr = np.asarray([1,2])

spec = OrderedDict()
spec['x'] = int32

@jitclass(spec)
class Foo(object):
    def __init__(self, x):
        self.x = x

    def otherFunc(self, optionalArg):
        if optionalArg is None:
            return self.x + 10
        else:
            return len(optionalArg)
@njit
def useOtherFunc(arg1, optArg):
    foo = Foo(arg1)

    print(foo.otherFunc(optArg))

arg1 = 5

useOtherFunc(arg1, np_arr)   # Output: 2
useOtherFunc(arg1, None)     # Output : 15

有关上面显示的示例,请参阅 this colab notebook