Numba 没有提高性能

Numba is not enhancing the performance

我正在测试一些采用 numpy 数组的函数的 numba 性能,并比较:

import numpy as np
from numba import jit, vectorize, float64
import time
from numba.core.errors import NumbaWarning
import warnings

warnings.simplefilter('ignore', category=NumbaWarning)

@jit(nopython=True, boundscheck=False) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a):     # Function is compiled to machine code when called the first time
    trace = 0.0
    for i in range(a.shape[0]):   # Numba likes loops
        trace += np.tanh(a[i, i]) # Numba likes NumPy functions
    return a + trace              # Numba likes NumPy broadcasting
   
class Main(object):
    def __init__(self) -> None:
        super().__init__()
        self.mat     = np.arange(100000000, dtype=np.float64).reshape(10000, 10000)

    def my_run(self):
        st = time.time()
        trace = 0.0
        for i in range(self.mat.shape[0]):   
            trace += np.tanh(self.mat[i, i]) 
        res = self.mat + trace
        print('Python Diration: ', time.time() - st)
        return res                           
    
    def jit_run(self):
        st = time.time()
        res = go_fast(self.mat)
        print('Jit Diration: ', time.time() - st)
        return res
        
obj = Main()
x1 = obj.my_run()
x2 = obj.jit_run()

输出为:

Python Diration:  0.2164750099182129
Jit Diration:  0.5367801189422607

如何获得此示例的增强版本?

Numba 实现的较慢执行时间是由于 编译时间 因为 Numba 在使用函数时编译函数(只有第一次,除非函数的类型参数改变)。它这样做是因为它无法在调用函数之前知道参数的类型。希望您可以 为 Numba 指定参数类型 以便它可以直接编译函数(当执行装饰器函数时)。这是结果代码:

@njit('float64[:,:](float64[:,:])')
def go_fast(a):
    trace = 0.0
    for i in range(a.shape[0]):
        trace += np.tanh(a[i, i])
    return a + trace

请注意,njitjit+nopython=True 的快捷方式,并且 boundscheck 已默认设置为 False(请参阅 doc).

在我的机器上,这导致 Numpy 和 Numba 的执行时间相同。实际上,执行时间不受 tanh 函数计算的限制。它 受表达式 a + trace 的限制(对于 Numba 和 Numpy)。预计执行时间相同,因为两者的实现方式相同:它们创建一个临时的新数组来执行加法。由于在 x86 平台上 page faults and the use of the RAM (a is fully read from the RAM and the temporary array is fully stored in RAM). If you want a faster computation, then you need to perform the operation in-place (this prevent page faults and expensive cache-line write allocations,因此创建新的临时数组非常昂贵。