可选地使用 jit 将参数传递给另一个函数
Optionally passing parameters onto another function with jit
我正在尝试对 python 函数进行 jit 编译,并使用可选参数来更改另一个函数调用的参数。
我认为 jit 可能出错的地方是可选参数的默认值是 None,而 jit 不知道如何处理它,或者至少不知道如何处理当它变成一个 numpy 数组时。请参阅下面的粗略概述:
@jit(nopython=True)
def foo(otherFunc,arg1, optionalArg=None):
if optionalArg is not None:
out=otherFunc(arg1,optionalArg)
else:
out=otherFunc(arg1)
return out
其中 optionalArg 是 None 或 numpy 数组
一个解决方案是将其转换为如下所示的三个函数,但这感觉有点笨拙,我不喜欢它,尤其是因为速度对于这项任务非常重要。
def foo(otherFunc,arg1,optionalArg=None):
if optionalArg is not None:
out=func1(otherFunc,arg1,optionalArg)
else:
out=func2(otherFunc,arg1)
return out
@jit(nopython=True)
def func1(otherFunc,arg1,optionalArg):
out=otherFunc(arg1,optionalArg)
return out
@jit(nopython=True)
def func2(otherFunc,arg1):
out=otherFunc(arg1)
return out
请注意,除了调用 otherFunc 之外还有其他事情正在发生,这使得使用 jit 是值得的,但我几乎可以肯定这不是问题所在,因为这在没有 optionalArg 部分的情况下工作,所以我决定不包括它。
对于那些好奇的人来说,它的 运行ge-kutta order 4 实现带有可选的额外参数以传递给微分方程。如果你想看到整个事情就问。
回溯相当长,但这里是其中的一部分:
inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
Traceback (most recent call last):
File "<ipython-input-38-478197aa6a1a>", line 1, in <module>
inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E168C358>:
This continues...
inte.rk4 是 foo 的等价物,de2 是 otherFunc,y0、0.001 和 200 只是值,我在上面的问题描述中换成了 arg1,而 vals 是 optionalArg。
当我在省略 vals 参数的情况下尝试 运行 时会发生类似的事情:
ysExp=inte.rk4(deExp,y0,0.001,200)
Traceback (most recent call last):
File "<ipython-input-39-7dde4bcbdc2f>", line 1, in <module>
ysExp=inte.rk4(deExp,y0,0.001,200)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E048EA90>:
This continues...
如果您看到文档 here,您可以在 Numba 中明确指定 optional
类型参数。例如(这与文档中的示例相同):
>>> @jit((optional(intp),))
... def f(x):
... return x is not None
...
>>> f(0)
True
>>> f(None)
False
此外,根据正在进行的对话 this Github issue,您可以使用以下解决方法来实施可选关键字。我修改了 github 问题中提供的解决方案中的代码以适合您的示例:
from numba import jitclass, int32, njit
from collections import OrderedDict
import numpy as np
np_arr = np.asarray([1,2])
spec = OrderedDict()
spec['x'] = int32
@jitclass(spec)
class Foo(object):
def __init__(self, x):
self.x = x
def otherFunc(self, optionalArg):
if optionalArg is None:
return self.x + 10
else:
return len(optionalArg)
@njit
def useOtherFunc(arg1, optArg):
foo = Foo(arg1)
print(foo.otherFunc(optArg))
arg1 = 5
useOtherFunc(arg1, np_arr) # Output: 2
useOtherFunc(arg1, None) # Output : 15
有关上面显示的示例,请参阅 this colab notebook。
我正在尝试对 python 函数进行 jit 编译,并使用可选参数来更改另一个函数调用的参数。
我认为 jit 可能出错的地方是可选参数的默认值是 None,而 jit 不知道如何处理它,或者至少不知道如何处理当它变成一个 numpy 数组时。请参阅下面的粗略概述:
@jit(nopython=True)
def foo(otherFunc,arg1, optionalArg=None):
if optionalArg is not None:
out=otherFunc(arg1,optionalArg)
else:
out=otherFunc(arg1)
return out
其中 optionalArg 是 None 或 numpy 数组
一个解决方案是将其转换为如下所示的三个函数,但这感觉有点笨拙,我不喜欢它,尤其是因为速度对于这项任务非常重要。
def foo(otherFunc,arg1,optionalArg=None):
if optionalArg is not None:
out=func1(otherFunc,arg1,optionalArg)
else:
out=func2(otherFunc,arg1)
return out
@jit(nopython=True)
def func1(otherFunc,arg1,optionalArg):
out=otherFunc(arg1,optionalArg)
return out
@jit(nopython=True)
def func2(otherFunc,arg1):
out=otherFunc(arg1)
return out
请注意,除了调用 otherFunc 之外还有其他事情正在发生,这使得使用 jit 是值得的,但我几乎可以肯定这不是问题所在,因为这在没有 optionalArg 部分的情况下工作,所以我决定不包括它。
对于那些好奇的人来说,它的 运行ge-kutta order 4 实现带有可选的额外参数以传递给微分方程。如果你想看到整个事情就问。
回溯相当长,但这里是其中的一部分:
inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
Traceback (most recent call last):
File "<ipython-input-38-478197aa6a1a>", line 1, in <module>
inte.rk4(de2,y0,0.001,200,vals=np.ones(4))
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E168C358>:
This continues...
inte.rk4 是 foo 的等价物,de2 是 otherFunc,y0、0.001 和 200 只是值,我在上面的问题描述中换成了 arg1,而 vals 是 optionalArg。
当我在省略 vals 参数的情况下尝试 运行 时会发生类似的事情:
ysExp=inte.rk4(deExp,y0,0.001,200)
Traceback (most recent call last):
File "<ipython-input-39-7dde4bcbdc2f>", line 1, in <module>
ysExp=inte.rk4(deExp,y0,0.001,200)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in _compile_for_args
error_rewrite(e, 'typing')
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\dispatcher.py", line 317, in error_rewrite
reraise(type(e), e, None)
File "C:\Users\Alex\Anaconda3\lib\site-packages\numba\six.py", line 658, in reraise
raise value.with_traceback(tb)
TypingError: Internal error at <numba.typeinfer.CallConstraint object at 0x00000258E048EA90>:
This continues...
如果您看到文档 here,您可以在 Numba 中明确指定 optional
类型参数。例如(这与文档中的示例相同):
>>> @jit((optional(intp),))
... def f(x):
... return x is not None
...
>>> f(0)
True
>>> f(None)
False
此外,根据正在进行的对话 this Github issue,您可以使用以下解决方法来实施可选关键字。我修改了 github 问题中提供的解决方案中的代码以适合您的示例:
from numba import jitclass, int32, njit
from collections import OrderedDict
import numpy as np
np_arr = np.asarray([1,2])
spec = OrderedDict()
spec['x'] = int32
@jitclass(spec)
class Foo(object):
def __init__(self, x):
self.x = x
def otherFunc(self, optionalArg):
if optionalArg is None:
return self.x + 10
else:
return len(optionalArg)
@njit
def useOtherFunc(arg1, optArg):
foo = Foo(arg1)
print(foo.otherFunc(optArg))
arg1 = 5
useOtherFunc(arg1, np_arr) # Output: 2
useOtherFunc(arg1, None) # Output : 15
有关上面显示的示例,请参阅 this colab notebook。