Numba fails to compile np.select based function in nopython mode
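I am trying to use numba.njit to compile a working piecewise function. The Python function, defined using Numpy, is as follows:
(For anyone interested in the Sympy origins of this problem, see the notes below.)
Minimal example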
    from numpy import select, less, nan

    def f(t):
        condlist = [less(t, 5), less(t, 15), less(t, 20), True]
        choicelist = [1, 0, 1, 0]
        return select(condlist, choicelist, default=nan)
See below for confirmation that this function works in plain Python.
Problem: however, Numba fails to JIT this function in nopython mode:
    import numpy as np
    from numba import njit

    jit_f = njit(f)
    x = np.linspace(0, 50, 500)
    jit_f(x)
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
Input In [86], in <cell line: 5>()
2 jit_f = njit(f)
4 x = np.linspace(0,50,500)
----> 5 jit_f(x)
File /usr/local/lib/python3.8/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465 f"by the following argument(s):\n{args_str}\n")
466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
471 error_rewrite(e, 'unsupported_error')
File /usr/local/lib/python3.8/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function select at 0x105f4d310>) found for signature:
>>> select(LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True))), list(int64)<iv=[1, 0, 1, 0]>, default=float64)
There are 2 candidate implementations:
- Of which 1 did not match due to:
Overload in function 'np_select': File: numba/np/arraymath.py: Line 4358.
With argument(s): '(Poison<LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True)))>, list(int64)<iv=None>, default=float64)':
Rejected as the implementation raised a specific error:
TypingError: Poison type used in arguments; got Poison<LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True)))>
raised from /usr/local/lib/python3.8/site-packages/numba/core/types/functions.py:236
- Of which 1 did not match due to:
Overload in function 'np_select': File: numba/np/arraymath.py: Line 4358.
With argument(s): '(LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True))), list(int64)<iv=[1, 0, 1, 0]>, default=float64)':
Rejected as the implementation raised a specific error:
NumbaTypeError: condlist must be a List or a Tuple
raised from /usr/local/lib/python3.8/site-packages/numba/np/arraymath.py:4375
During: resolving callee type: Function(<function select at 0x105f4d310>)
During: typing of call at /var/folders/dt/q6vbs0g56s70g4p2kyfj4tvh0000gn/T/ipykernel_61924/130570246.py (5)
File "../../../../../../../../../var/folders/dt/q6vbs0g56s70g4p2kyfj4tvh0000gn/T/ipykernel_61924/130570246.py", line 5:
<source missing, REPL/exec in use?>
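I am no Numba expert, but my feeling is that there is some syntax mistake. I have played around with passing Numpy arrays and with different formats for condlist and choicelist, but no luck so far.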
Other notes
The Python function behaves as expected, in this case giving some binary oscillations and then zero:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 50, 500)
plt.plot(x, f(x))
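For any Sympy enthusiasts, the overarching problem here is using Numba to JIT-compile a lambda generated by Sympy from a sympy.Piecewise. A lambda very similar to f(t) in the example above can be generated automatically by sympy.lambdify on a piecewise function.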
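Numba does not currently implement all Numpy features, and support is sometimes limited. You can find the list of supported functions in the documentation. For np.select, the documentation states that support is limited to:
only using homogeneous lists or tuples for the first two arguments, condlist and choicelist. Additionally, these two arguments can only contain arrays (unlike Numpy that also accepts tuples).
The problem is that condlist is not homogeneous: the first 3 items of the list are arrays, while the last one is a boolean. Furthermore, choicelist contains integers, whereas it must contain arrays.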
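To see the mismatch concretely, one can inspect the element types of the two lists from the question in plain Python. This is only an illustrative check (the sample input `t` and the names `cond_types` and `choice_types` are chosen here), not something Numba itself runs:

```python
import numpy as np

t = np.linspace(0, 50, 500)  # sample input, chosen for illustration

condlist = [np.less(t, 5), np.less(t, 15), np.less(t, 20), True]
choicelist = [1, 0, 1, 0]

# Three arrays followed by a plain bool: the list is not homogeneous.
cond_types = [type(c).__name__ for c in condlist]
# Plain Python ints: homogeneous, but not arrays.
choice_types = [type(c).__name__ for c in choicelist]

print(cond_types)
print(choice_types)
```

Both lists therefore fall outside what the documentation quoted above allows, which matches the "condlist must be a List or a Tuple" rejection in the traceback.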
One way to solve this problem is to use the following code:
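    import numpy as np
    from numpy import select, less, nan

    def f(t):
        # Every condition and every choice is now an array of the same shape.
        condlist = [less(t, 5), less(t, 15), less(t, 20), np.full(t.size, True)]
        all_zeros = np.zeros(t.size)
        all_ones = np.ones(t.size)
        choicelist = [all_ones, all_zeros, all_ones, all_zeros]
        return select(condlist, choicelist, default=nan)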
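However, please do not use this code, because it is inefficient. Indeed, it creates many temporary arrays that are slow to create and fill. The code will certainly be memory-bound, and memory is a scarce resource that has improved only slowly over the last decades (this is known as the "memory wall"). Optimizing such code is hard, and Numba is not faster than Numpy for this. In fact, Numpy already does this quite efficiently, since it is implemented in C and most of its functions are carefully optimized. Numba is fast when you use loops and avoid creating (useless) temporary arrays. In short: Numba likes loops, as opposed to Numpy. Here is a faster solution: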
    import numpy as np

    def f(t):
        # Single pass over the input, no temporary arrays.
        result = np.empty(t.size)
        for i in range(t.size):
            result[i] = t[i] < 5 or 15 <= t[i] < 20
        return result
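As a sanity check, the loop version can be compiled and compared against the homogeneous np.select formulation. The sketch below is one way to do that, not a definitive benchmark. The names `f_loop` and `f_select` are made up for this example, the branch is written as an explicit 1.0/0.0 for clarity, and the try/except fallback to plain Python is an assumption added only so the snippet also runs where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: run uncompiled if Numba is unavailable.
    def njit(func):
        return func


@njit
def f_loop(t):
    # One pass over the input, no temporary arrays.
    result = np.empty(t.size)
    for i in range(t.size):
        if t[i] < 5 or 15 <= t[i] < 20:
            result[i] = 1.0
        else:
            result[i] = 0.0
    return result


def f_select(t):
    # Homogeneous np.select version, run with plain NumPy as the reference.
    condlist = [t < 5, t < 15, t < 20, np.full(t.size, True)]
    ones, zeros = np.ones(t.size), np.zeros(t.size)
    return np.select(condlist, [ones, zeros, ones, zeros], default=np.nan)


x = np.linspace(0, 50, 500)
same = np.array_equal(f_loop(x), f_select(x))
print(same)
```

Because the last condition is always true, the default value is never used, so the two functions can be compared element by element.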
Note that using a boolean or short-integer output type (e.g. int8) should be faster, since floating-point numbers are slower to compute with and take more space in memory.
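A minimal sketch of that suggestion, assuming a boolean result is acceptable to the caller. The name `f_bool` is hypothetical, and the fallback used when Numba is missing is only a convenience for running the snippet:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: run uncompiled if Numba is unavailable.
    def njit(func):
        return func


@njit
def f_bool(t):
    # One byte per element instead of eight for float64.
    result = np.empty(t.size, dtype=np.bool_)
    for i in range(t.size):
        result[i] = t[i] < 5 or 15 <= t[i] < 20
    return result


x = np.linspace(0, 50, 500)
y = f_bool(x)
print(y.dtype, y.nbytes, x.nbytes)
```

If a numeric array is needed downstream, `y.astype(np.int8)` keeps the same one-byte footprint. Whether the smaller output type makes a measurable difference depends on the array size and the machine, so it is worth timing on real inputs.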