Numba fails to compile np.select based function in nopython mode

I am trying to compile a valid piecewise function with numba.njit. The Python function is defined as follows, using Numpy:

(For anyone interested in the Sympy origins of this problem, see the notes below.)

Minimal example

from numpy import select, less, nan
def f(t):
    condlist = [less(t, 5), less(t, 15), less(t, 20), True]
    choicelist = [1, 0, 1, 0]
    return select(condlist, choicelist, default=nan)

See below for confirmation that this function works in plain Python.

Problem: Numba, however, fails to JIT-compile this function in nopython mode:

import numpy as np
from numba import njit

jit_f = njit(f)

x = np.linspace(0, 50, 500)
jit_f(x)

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Input In [86], in <cell line: 5>()
      2 jit_f = njit(f)
      4 x = np.linspace(0,50,500)
----> 5 jit_f(x)

File /usr/local/lib/python3.8/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File /usr/local/lib/python3.8/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function select at 0x105f4d310>) found for signature:
 
 >>> select(LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True))), list(int64)<iv=[1, 0, 1, 0]>, default=float64)
 
There are 2 candidate implementations:
      - Of which 1 did not match due to:
      Overload in function 'np_select': File: numba/np/arraymath.py: Line 4358.
        With argument(s): '(Poison<LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True)))>, list(int64)<iv=None>, default=float64)':
       Rejected as the implementation raised a specific error:
         TypingError: Poison type used in arguments; got Poison<LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True)))>
  raised from /usr/local/lib/python3.8/site-packages/numba/core/types/functions.py:236
      - Of which 1 did not match due to:
      Overload in function 'np_select': File: numba/np/arraymath.py: Line 4358.
        With argument(s): '(LiteralList((array(bool, 1d, C), array(bool, 1d, C), array(bool, 1d, C), Literal[bool](True))), list(int64)<iv=[1, 0, 1, 0]>, default=float64)':
       Rejected as the implementation raised a specific error:
         NumbaTypeError: condlist must be a List or a Tuple
  raised from /usr/local/lib/python3.8/site-packages/numba/np/arraymath.py:4375

During: resolving callee type: Function(<function select at 0x105f4d310>)
During: typing of call at /var/folders/dt/q6vbs0g56s70g4p2kyfj4tvh0000gn/T/ipykernel_61924/130570246.py (5)


File "../../../../../../../../../var/folders/dt/q6vbs0g56s70g4p2kyfj4tvh0000gn/T/ipykernel_61924/130570246.py", line 5:
<source missing, REPL/exec in use?>

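I'm not a Numba expert, but my feeling is that I have some syntax wrong. I have played around with passing Numpy arrays and with different formats for condlist and choicelist, but no luck so far.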

Additional notes

The Python function behaves as expected, in this case giving a binary oscillation followed by zero:

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0, 50, 500)
plt.plot(x, f(x))

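For any Sympy fans: the bigger question here is how to JIT-compile, with Numba, a lambda generated by Sympy from a sympy.Piecewise. A lambda very similar to f(t) in the example above is generated automatically by sympy.lambdify when it is applied to a piecewise function.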
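As a rough illustration of that note, the sketch below builds a Piecewise expression matching f(t) and lambdifies it with the Numpy backend. The symbol name t, the variable names, and the sample inputs are illustrative choices and not part of the original problem; the exact source text that lambdify generates may differ between Sympy versions.

```python
import inspect

import numpy as np
import sympy as sp

# Illustrative symbol and expression mirroring f(t) from the minimal example.
t = sp.symbols("t")
expr = sp.Piecewise((1, t < 5), (0, t < 15), (1, t < 20), (0, True))

# Generate a Numpy-backed callable from the piecewise expression.
f_lambdified = sp.lambdify(t, expr, modules="numpy")

# Show the generated source; the exact text depends on the Sympy version.
print(inspect.getsource(f_lambdified))

# Evaluate on a few sample points, one per branch of the piecewise function.
values = f_lambdified(np.array([0.0, 10.0, 17.0, 30.0]))
print(values)
```

Printing the generated source is a quick way to see which Numpy calls Numba would be asked to compile.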

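Numba does not currently implement every Numpy function, and the support it does offer is sometimes limited. You can find the list of supported functions in the documentation. For np.select, the documentation states that support is limited to: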

only using homogeneous lists or tuples for the first two arguments, condlist and choicelist. Additionally, these two arguments can only contain arrays (unlike Numpy that also accepts tuples).

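The problem is that condlist is not homogeneous: the first 3 items of the list are arrays, while the last one is a boolean. In addition, choicelist contains integers, whereas it must contain arrays.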

One way to work around this is the following code:

def f(t):
    condlist = [less(t, 5), less(t, 15), less(t, 20), np.full(t.size, True)]
    all_zeros = np.zeros(t.size)
    all_ones = np.ones(t.size)
    choicelist = [all_ones, all_zeros, all_ones, all_zeros]
    return select(condlist, choicelist, default=nan)

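However, please do not use this code, because it is inefficient. It creates many temporary arrays, which are slow to allocate and fill. The code will almost certainly be memory-bound, and memory throughput is a scarce resource that has improved only slowly over the last decades (this is known as the "memory wall"). Optimizing this kind of code is hard, and Numba is no faster than Numpy here. In fact, Numpy already does this quite efficiently, because it is implemented in C and most of its functions are carefully optimized. Numba is fast when you use loops and avoid creating (useless) temporary arrays. In short: Numba likes loops, not Numpy calls. Here is a faster solution: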

def f(t):
    result = np.empty(t.size)
    for i in range(t.size):
        result[i] = t[i] < 5 or 15 <= t[i] < 20
    return result
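To check that the loop reproduces the original np.select version, a sketch along the following lines can be used. The names f_select and f_loop are illustrative, and the try/except fallback that turns njit into a no-op when Numba is not installed is an assumption added only so that the comparison runs anywhere.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python when Numba is unavailable,
    # so the comparison below still runs (just without compilation).
    def njit(func):
        return func


def f_select(t):
    # Original Numpy formulation from the question (not jitted).
    condlist = [np.less(t, 5), np.less(t, 15), np.less(t, 20), True]
    choicelist = [1, 0, 1, 0]
    return np.select(condlist, choicelist, default=np.nan)


@njit
def f_loop(t):
    # Loop formulation: 1 where t < 5 or 15 <= t < 20, else 0.
    result = np.empty(t.size)
    for i in range(t.size):
        result[i] = t[i] < 5 or 15 <= t[i] < 20
    return result


x = np.linspace(0, 50, 500)
same = np.array_equal(f_select(x), f_loop(x))
print(same)
```

The loop only handles one-dimensional input, which matches the x used in the question.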

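Note that using a boolean or short-integer (e.g. int8) output type should be faster: a float64 result takes 8 bytes per element instead of 1, so there is more memory to write, and floating-point values are more expensive to produce.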
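A minimal sketch of the boolean-output variant, assuming a one-dimensional input. The name f_bool is illustrative, and the njit fallback is an assumption added so that the snippet also runs without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: no-op decorator when Numba is not installed.
    def njit(func):
        return func


@njit
def f_bool(t):
    # Same piecewise logic, stored as 1-byte booleans instead of 8-byte floats.
    result = np.empty(t.size, dtype=np.bool_)
    for i in range(t.size):
        result[i] = t[i] < 5 or 15 <= t[i] < 20
    return result


x = np.linspace(0, 50, 500)
y = f_bool(x)
print(y.dtype, y.nbytes)
```

If a numeric array is needed later (for plotting, say), y.astype(np.float64) converts it at that point.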