Mixed data type inputs for numba njit

I have a large array to operate on, for example a matrix transposition, and numba is much faster:

# test_transpose.py
import numpy as np
import numba as nb
import time


@nb.njit('float64[:,:](float64[:,:])', parallel=True)
def transpose(x):
    r, c = x.shape
    x2 = np.zeros((c, r))
    # Parallelize over the output rows (the input columns).
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j, i]
    return x2


if __name__ == "__main__":
    x = np.random.randn(int(3e6), 50)
    t = time.time()
    x = x.transpose().copy()
    print(f"numpy transpose: {round(time.time() - t, 4)} secs")

    x = np.random.randn(int(3e6), 50)
    t = time.time()
    x = transpose(x)
    print(f"numba paralleled transpose: {round(time.time() - t, 4)} secs")

Running it from the Windows command prompt:

D:\data\test>python test_transpose.py
numpy transpose: 2.0961 secs
numba paralleled transpose: 0.8584 secs

However, when I want to pass in another large matrix that contains integers, defining x as

x = np.random.randint(int(3e6), size=(int(3e6), 50), dtype=np.int64)

the following exception is raised:

Traceback (most recent call last):
  File "test_transpose.py", line 39, in <module>
    x = transpose(x)
  File "C:\Program Files\Python38\lib\site-packages\numba\core\dispatcher.py", line 703, in _explain_matching_error
    raise TypeError(msg)
TypeError: No matching definition for argument type(s) array(int64, 2d, C)

The compiled function does not accept the integer input matrix. If I remove the type signature, like this:

@nb.njit(parallel=True) # 'float64[:,:](float64[:,:])'
def transpose(x):
    r, c = x.shape
    x2 = np.zeros((c, r))
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
    return x2

it is slower:

D:\Data\test>python test_transpose.py
numba paralleled transpose: 1.6653 secs

As expected, using @nb.njit('int64[:,:](int64[:,:])', parallel=True) for the integer data matrix is faster.

So, how can I still allow mixed data type inputs while keeping the speed, instead of creating a separate function for each type?

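The problem is that the Numba function is only defined for the float64 type, not for int64. The types need to be specified because Numba compiles the Python code to native code with well-defined types. You can add multiple signatures to a Numba function: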

@nb.njit(['float64[:,:](float64[:,:])', 'int64[:,:](int64[:,:])'], parallel=True)
def transpose(x):
    r, c = x.shape
    # Specifying the dtype is very important here: without it the result
    # would be float64, which does not match the int64[:,:] return type.
    # This is a good habit to take to avoid numerical issues and
    # slower performance in NumPy too.
    x2 = np.zeros((c, r), dtype=x.dtype)
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j, i]
    return x2

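As for "it is slower": this is due to lazy compilation. The first execution includes the compilation time. This is not the case when a signature is specified, because the function is then compiled eagerly, when the decorator runs.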

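As for "numba is much faster": well, that is not saying much, given that many cores are used. In fact, a naive transposition is very inefficient on large matrices (in this case, about 90% of the memory throughput is wasted on large arrays). There are faster algorithms. For more information, please read this (it only considers in-place 2D square transposition, which is simpler, but the idea is the same). Also note that the wider the type, the bigger the array, and the bigger the array, the slower the transposition.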