Mixed data type inputs for Numba njit
I have a large array to operate on, for example a matrix transpose. numba is much faster:
# test_transpose.py
import numpy as np
import numba as nb
import time


@nb.njit('float64[:,:](float64[:,:])', parallel=True)
def transpose(x):
    r, c = x.shape
    x2 = np.zeros((c, r))
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
    return x2


if __name__ == "__main__":
    x = np.random.randn(int(3e6), 50)
    t = time.time()
    x = x.transpose().copy()
    print(f"numpy transpose: {round(time.time() - t, 4)} secs")

    x = np.random.randn(int(3e6), 50)
    t = time.time()
    x = transpose(x)
    print(f"numba paralleled transpose: {round(time.time() - t, 4)} secs")
Running it in the Windows command prompt:
D:\data\test>python test_transpose.py
numpy transpose: 2.0961 secs
numba paralleled transpose: 0.8584 secs
However, I want to pass in another large matrix that holds integers, with x created as
x = np.random.randint(int(3e6), size=(int(3e6), 50), dtype=np.int64)
and an exception is raised:
Traceback (most recent call last):
  File "test_transpose.py", line 39, in <module>
    x = transpose(x)
  File "C:\Program Files\Python38\lib\site-packages\numba\core\dispatcher.py", line 703, in _explain_matching_error
    raise TypeError(msg)
TypeError: No matching definition for argument type(s) array(int64, 2d, C)
It does not recognize the integer input matrix. If I drop the type signature so that the integer matrix is accepted, as in
@nb.njit(parallel=True)  # 'float64[:,:](float64[:,:])'
def transpose(x):
    r, c = x.shape
    x2 = np.zeros((c, r))
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
    return x2
it is slower:
D:\Data\test>python test_transpose.py
numba paralleled transpose: 1.6653 secs
As expected, using @nb.njit('int64[:,:](int64[:,:])', parallel=True) for the integer matrix is faster.
So, how can I still allow mixed data type inputs but keep the speed, instead of creating a separate function for each type?
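The problem is that the Numba function is defined only for the float64 type, not for int64. A signature pins the function to specific types, because Numba compiles the Python code to native code with well-defined types. You can add multiple signatures to a Numba function: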
@nb.njit(['float64[:,:](float64[:,:])', 'int64[:,:](int64[:,:])'], parallel=True)
def transpose(x):
    r, c = x.shape
    # Specifying the dtype is very important here: without it, np.zeros
    # always allocates float64, whatever the input type is.
    # It is also a good habit in NumPy, to avoid numerical issues and
    # slower performance.
    x2 = np.zeros((c, r), dtype=x.dtype)
    for i in nb.prange(c):
        for j in range(r):
            x2[i, j] = x[j][i]
    return x2
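Regarding "It is slower": this is due to lazy compilation. Without a signature, the function is compiled on its first call, so the first execution includes the compilation time. This is not the case when signatures are specified, because the function is then compiled eagerly, when it is defined.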
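Regarding "numba is much faster": well, not by that much, considering that many cores are used. In fact, the naive transposition is very inefficient on big matrices (in this case it wastes about 90% of the memory throughput on big arrays). There are faster algorithms. For more information, see the referenced post (it only covers in-place 2D square transposition, which is simpler, but the idea is the same). Also note that the wider the type, the bigger the array, and the bigger the array, the slower the transposition.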