为什么 numba 在 numpy linspace 中引发类型错误

Why numba raise a type error in numpy linspace

我正在使用 numba 0.34.0 和 numpy 1.13.1。一个小例子如下:

import numpy as np    
from numba import jit
@jit(nopython=True)
def initial(veh_size):
    t = np.linspace(0, (100 - 1) * 30, 100, dtype=np.int32)
    t0 = np.linspace(0, (veh_size - 1) * 30, veh_size, dtype=np.int32)
    return t0

initial(100)

带有 tt0 的行都有相同的错误消息。

错误信息:

numba.errors.InternalError: 
[1] During: resolving callee type: Function(<function linspace at 0x000001F977678C80>)
[2] During: typing of call at ***/test1.py (6)

因为 np.linspace 的 numba 版本不接受 dtype 参数 (source: numba 0.34 documentation):

2.7.3.3. Other functions

The following top-level functions are supported:

  • [...]

  • numpy.linspace() (only the 3-argument form)

  • [...]

您需要使用 astype 在 nopython-numba 函数中转换它:

import numpy as np    
from numba import jit
@jit(nopython=True)
def initial(veh_size):
    t = np.linspace(0, (100 - 1) * 30, 100).astype(np.int32)
    t0 = np.linspace(0, (veh_size - 1) * 30, veh_size).astype(np.int32)
    return t0

initial(100)

或者只是不要在 nopython-numba 函数中使用 np.linspace 并将其作为参数传递。这避免了临时数组,我怀疑 numbas np.linspace 比 NumPys.