Why does the numba library in Python not recognize a NumPy 2D array?
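I've just started learning numba, and this is an exercise to measure the time it takes to solve a simple matrix problem. My goal is to run the program in parallel using the Python numba library. The program has a function create_matrix(row: int, col: int) that takes two inputs and creates a 2D matrix. I then create two matrices, find the sum of each one's primary diagonal, and add the two sums together.
The problem is that numba doesn't seem to understand NumPy 2D arrays. Any help would be appreciated. Thanks.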
# imports
from numba import njit
import numpy as np

# create a 2D matrix by consecutive numbers - starting 1
def create_matrix(row, col):
    arr = np.array([[j + (col * i) for j in range(1, col + 1)] for i in range(row)])
    return np.matrix(arr)

# calculate the sum of primary diagonals of matrix1
jitted_function = njit()(create_matrix)
m1 = jitted_function(1, 1)
print(f"Matrix 1 : {m1}")
print(f"Matrix 1 diagonal: {np.diagonal(m1)}")
print(f"Matrix 1 sum of primary diagonal is : {np.trace(m1)}")
mat1_sum = np.trace(m1, dtype='i')

# calculate the sum of primary diagonals of matrix2
m2 = create_matrix(4, 4)
print(f"Matrix 2 : {m2}")
print(f"Matrix 2 diagonal : {np.diagonal(m2)}")
print(f"Matrix 2 Sum of diagonal is : {np.trace(m2)}")
mat2_sum = np.trace(m2, dtype='i')

sum_of_two_diagonals = mat1_sum + mat2_sum
print(f"THE SUM IS : {sum_of_two_diagonals}")
The error is:
Traceback (most recent call last):
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\practise.py", line 21, in <module>
    m1 = jitted_function(1, 1)
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\venv\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\venv\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 >>> setitem(array(undefined, 1d, C), int64, array(int64, 1d, C))
There are 16 candidate implementations:
  - Of which 16 did not match due to:
  Overload of function 'setitem': File: <numerous>: Line N/A.
    With argument(s): '(array(undefined, 1d, C), int64, array(int64, 1d, C))':
   No match.
During: typing of setitem at E:\Users\SoniTech\PycharmProjects\computer_hardware\practise.py (10)
File "practise.py", line 10:
def create_matrix(row, col):
    <source elided>
    """
    arr = np.array([[j + (col * i) for j in range(1, col + 1)] for i in range(row)]) # create a matrix starting 1
    ^
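Some of the functions you are using, such as np.matrix(), are not supported in njit (nopython) mode. You can instead rewrite the create_matrix() function so that numba can delegate to its own implementations: preallocate the array and fill it with explicit loops.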
# imports
from numba import njit
import numpy as np

# create a 2D matrix by consecutive numbers - starting 1
def create_matrix(row, col):
    arr = np.zeros((row, col))
    for i in range(row):
        for j in range(1, col + 1):
            arr[i, j-1] = j + (col * i)
    return arr

# calculate the sum of primary diagonals of matrix1
jitted_function = njit()(create_matrix)
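
To round out the rewrite above, here is a minimal end-to-end sketch that also sums the two diagonals, as the question intends. Several things here are my assumptions rather than part of the original answer: the array uses dtype=np.int64 so the result stays an integer, the helper name diagonal_sum is hypothetical, the diagonal is summed with an explicit loop instead of np.trace(..., dtype='i'), and the try/except fallback exists only so the sketch still runs where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch only: run as plain Python when numba is missing.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit
def create_matrix(row, col):
    # Fill a preallocated integer array with 1, 2, 3, ... in row-major order.
    arr = np.zeros((row, col), dtype=np.int64)
    for i in range(row):
        for j in range(col):
            arr[i, j] = col * i + j + 1
    return arr


@njit
def diagonal_sum(arr):
    # Hypothetical helper: sum the primary diagonal with an explicit loop.
    total = 0
    for k in range(min(arr.shape[0], arr.shape[1])):
        total += arr[k, k]
    return total


m1 = create_matrix(1, 1)
m2 = create_matrix(4, 4)
sum_of_two_diagonals = diagonal_sum(m1) + diagonal_sum(m2)
print(f"THE SUM IS : {sum_of_two_diagonals}")  # 1 + (1 + 6 + 11 + 16) = 35
```

Keeping both functions free of np.matrix and nested list comprehensions means everything inside the jitted code is a plain ndarray operation. For matrices this small, compilation time will likely dominate any timing measurement, so it may be worth calling each function once before timing it.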