Compiling an array of conditions with numba.jit takes a long time

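If I try to compile a function containing an array of conditions with numba's JIT compiler, compilation takes a very long time. The program looks essentially like this: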

from numba import jit
import numpy as np

@jit(nopython=True)
def foo(a, b):
    valid = [
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0),
        (a - 1 >= 0) and (b - 1 >= 0)
    ]

foo(1, 1)

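I have already removed everything that does not significantly change the compilation time. The problem appears as soon as I use more than 20 elements.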

| elements | compile time |
|----------|--------------|
|    21    |    2.7 s     |
|    22    |    5.1 s     |
|    23    |    10 s      |
|   ...    |    ...       |

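Apart from that, the function works correctly. Does anyone know why compiling such a function with numba takes so long? Building arrays in a similar way from combinations of integers or floats causes no problems.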
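To check whether this scaling reproduces on a particular numba version, something like the following sketch could be used. It generates the list-literal function from the question for a given number of elements and times the first call, which includes compilation. The helper names `make_source` and `time_first_call` are made up for illustration, and building the function with `exec` is an assumption of this sketch, not something the original post does.

```python
import time

# Same placeholder condition as in the question.
CONDITION = "(a - 1 >= 0) and (b - 1 >= 0)"


def make_source(n):
    """Return source for foo(a, b) that builds a list literal of n conditions."""
    lines = ["def foo(a, b):", "    valid = ["]
    lines += [f"        {CONDITION}," for _ in range(n)]
    lines += ["    ]"]
    return "\n".join(lines)


def time_first_call(n):
    """Compile the n-element variant with numba and time its first call.

    numba is imported here so that make_source() stays usable without it.
    """
    from numba import jit

    namespace = {}
    exec(make_source(n), namespace)  # defines foo in namespace
    foo = jit(nopython=True)(namespace["foo"])

    start = time.perf_counter()
    foo(1, 1)  # the first call triggers compilation for these argument types
    return time.perf_counter() - start
```

Calling `time_first_call(n)` for n from about 18 to 23 should show whether compile time roughly doubles per added element, as the reported timings suggest.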

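  1. You may want to report this on the numba issue tracker; it looks as though something in the compiler is going wrong for it to scale this badly.

  2. You could also consider whether you really need a large number of array statements like this, or whether the problem could be restructured more clearly. For example, could valid be a function that is called as needed, rather than a boolean array?

  3. That said, a workaround in the current version of numba is to unroll the conditions into separate assignments.

Taking your example: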

# "codegen": print the unrolled assignments to paste into the function body
for i in range(23):
    print(f'    valid[{i}] = (a - 1 >= 0) and (b - 1 >= 0)')

@jit(nopython=True)
def foo(a, b):
    valid = np.empty(23, dtype=np.bool_)
    valid[0] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[1] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[2] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[3] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[4] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[5] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[6] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[7] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[8] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[9] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[10] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[11] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[12] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[13] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[14] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[15] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[16] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[17] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[18] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[19] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[20] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[21] = (a - 1 >= 0) and (b - 1 >= 0)
    valid[22] = (a - 1 >= 0) and (b - 1 >= 0)

%time foo(1,1)
Wall time: 274 ms
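Suggestion 2 could look roughly like the sketch below, under the assumption that the real conditions differ only by constant offsets added to `a` and `b`. The offset arrays `DA` and `DB` and the functions `valid_at` and `count_valid` are hypothetical and not from the original post; the `try`/`except` fallback only lets the sketch run where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs without numba
    def njit(func):
        return func

# Hypothetical offsets: condition k is (a + DA[k] >= 0) and (b + DB[k] >= 0).
DA = np.array([-1, -1, 0, 1], dtype=np.int64)
DB = np.array([-1, 0, -1, 1], dtype=np.int64)


@njit
def valid_at(a, b, da, db, k):
    """Evaluate condition k on demand instead of storing every result."""
    return (a + da[k] >= 0) and (b + db[k] >= 0)


@njit
def count_valid(a, b, da, db):
    """Example consumer: loop over the conditions rather than listing them."""
    n = 0
    for k in range(da.shape[0]):
        if valid_at(a, b, da, db, k):
            n += 1
    return n
```

With this layout, adding conditions grows the offset arrays rather than the function body, so numba compiles the same small functions regardless of how many conditions there are.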