Numba 没有提高性能
Numba is not enhancing the performance
我正在测试一些采用 numpy
数组的函数的 numba 性能,并比较:
import numpy as np
from numba import jit, vectorize, float64
import time
from numba.core.errors import NumbaWarning
import warnings
warnings.simplefilter('ignore', category=NumbaWarning)
@jit(nopython=True, boundscheck=False) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a): # Function is compiled to machine code when called the first time
trace = 0.0
for i in range(a.shape[0]): # Numba likes loops
trace += np.tanh(a[i, i]) # Numba likes NumPy functions
return a + trace # Numba likes NumPy broadcasting
class Main(object):
def __init__(self) -> None:
super().__init__()
self.mat = np.arange(100000000, dtype=np.float64).reshape(10000, 10000)
def my_run(self):
st = time.time()
trace = 0.0
for i in range(self.mat.shape[0]):
trace += np.tanh(self.mat[i, i])
res = self.mat + trace
print('Python Diration: ', time.time() - st)
return res
def jit_run(self):
st = time.time()
res = go_fast(self.mat)
print('Jit Diration: ', time.time() - st)
return res
obj = Main()
x1 = obj.my_run()
x2 = obj.jit_run()
输出为:
Python Diration: 0.2164750099182129
Jit Diration: 0.5367801189422607
如何获得此示例的增强版本?
Numba 实现的较慢执行时间是由于 编译时间 因为 Numba 在使用函数时编译函数(只有第一次,除非函数的类型参数改变)。它这样做是因为它无法在调用函数之前知道参数的类型。希望您可以 为 Numba 指定参数类型 以便它可以直接编译函数(当执行装饰器函数时)。这是结果代码:
@njit('float64[:,:](float64[:,:])')
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace
请注意,njit
是 jit
+nopython=True
的快捷方式,并且 boundscheck
已默认设置为 False
(请参阅 doc).
在我的机器上,这导致 Numpy 和 Numba 的执行时间相同。实际上,执行时间不受 tanh
函数计算的限制。它 受表达式 a + trace
的限制(对于 Numba 和 Numpy)。预计执行时间相同,因为两者的实现方式相同:它们创建一个临时的新数组来执行加法。由于在 x86 平台上 page faults and the use of the RAM (a
is fully read from the RAM and the temporary array is fully stored in RAM). If you want a faster computation, then you need to perform the operation in-place (this prevent page faults and expensive cache-line write allocations,因此创建新的临时数组非常昂贵。
我正在测试一些采用 numpy
数组的函数的 numba 性能,并比较:
import numpy as np
from numba import jit, vectorize, float64
import time
from numba.core.errors import NumbaWarning
import warnings
warnings.simplefilter('ignore', category=NumbaWarning)
@jit(nopython=True, boundscheck=False) # Set "nopython" mode for best performance, equivalent to @njit
def go_fast(a): # Function is compiled to machine code when called the first time
trace = 0.0
for i in range(a.shape[0]): # Numba likes loops
trace += np.tanh(a[i, i]) # Numba likes NumPy functions
return a + trace # Numba likes NumPy broadcasting
class Main(object):
def __init__(self) -> None:
super().__init__()
self.mat = np.arange(100000000, dtype=np.float64).reshape(10000, 10000)
def my_run(self):
st = time.time()
trace = 0.0
for i in range(self.mat.shape[0]):
trace += np.tanh(self.mat[i, i])
res = self.mat + trace
print('Python Diration: ', time.time() - st)
return res
def jit_run(self):
st = time.time()
res = go_fast(self.mat)
print('Jit Diration: ', time.time() - st)
return res
obj = Main()
x1 = obj.my_run()
x2 = obj.jit_run()
输出为:
Python Diration: 0.2164750099182129
Jit Diration: 0.5367801189422607
如何获得此示例的增强版本?
Numba 实现的较慢执行时间是由于 编译时间 因为 Numba 在使用函数时编译函数(只有第一次,除非函数的类型参数改变)。它这样做是因为它无法在调用函数之前知道参数的类型。希望您可以 为 Numba 指定参数类型 以便它可以直接编译函数(当执行装饰器函数时)。这是结果代码:
@njit('float64[:,:](float64[:,:])')
def go_fast(a):
trace = 0.0
for i in range(a.shape[0]):
trace += np.tanh(a[i, i])
return a + trace
请注意,njit
是 jit
+nopython=True
的快捷方式,并且 boundscheck
已默认设置为 False
(请参阅 doc).
在我的机器上,这导致 Numpy 和 Numba 的执行时间相同。实际上,执行时间不受 tanh
函数计算的限制。它 受表达式 a + trace
的限制(对于 Numba 和 Numpy)。预计执行时间相同,因为两者的实现方式相同:它们创建一个临时的新数组来执行加法。由于在 x86 平台上 page faults and the use of the RAM (a
is fully read from the RAM and the temporary array is fully stored in RAM). If you want a faster computation, then you need to perform the operation in-place (this prevent page faults and expensive cache-line write allocations,因此创建新的临时数组非常昂贵。