Why does the numba library in Python not recognize a NumPy 2D array?

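I just started learning numba, and this is an exercise to measure the time it takes to solve a simple matrix problem. My goal is to run the program in parallel using the Python numba library. The program has a function create_matrix(row: int, col: int) that takes two inputs and creates a 2D matrix. I then create two matrices, find the sum of each one's primary diagonal, and add the two sums together. The problem is that numba does not seem to understand NumPy 2D arrays. Any help would be greatly appreciated. Thanks.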
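For reference, here is a minimal plain-NumPy sketch (no numba) of the computation described above, which can serve as a baseline to check a jitted version against. It assumes "consecutive numbers starting at 1" means row-major order, as in the question's create_matrix(); total_of_primary_diagonals() is a hypothetical helper name.

```python
import numpy as np


def create_matrix(row: int, col: int) -> np.ndarray:
    """Return a row x col int array filled with 1, 2, ..., row*col in row-major order."""
    return np.arange(1, row * col + 1).reshape(row, col)


def total_of_primary_diagonals(*matrices: np.ndarray) -> int:
    """Hypothetical helper: add up the traces (primary-diagonal sums) of all matrices."""
    return int(sum(np.trace(m) for m in matrices))


m1 = create_matrix(1, 1)  # [[1]]
m2 = create_matrix(4, 4)  # primary diagonal is 1, 6, 11, 16
print(np.trace(m1), np.trace(m2), total_of_primary_diagonals(m1, m2))  # 1 34 35
```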

#imports
from numba import njit
import numpy as np

# create a 2D matrix of consecutive numbers, starting at 1
def create_matrix(row, col):
    arr = np.array([[j + (col * i) for j in range(1, col + 1)] for i in range(row)])

    return np.matrix(arr)

# calculate the sum of primary diagonals of matrix1
jitted_function = njit()(create_matrix)
m1 = jitted_function(1, 1)
print(f"Matrix 1 : {m1}")
print(f"Matrix 1 diagonal: {np.diagonal(m1)}")
print(f"Matrix 1 sum of primary diagonal is : {np.trace(m1)}")
mat1_sum = np.trace(m1, dtype='i')


# calculate the sum of primary diagonals of matrix2
m2 = create_matrix(4, 4)
print(f"Matrix 2 : {m2}")
print(f"Matrix 2 diagonal : {np.diagonal(m2)}")
print(f"Matrix 2 Sum of diagonal is : {np.trace(m2)}")
mat2_sum = np.trace(m2, dtype='i')

sum_of_two_diagonals = mat1_sum + mat2_sum
print(f"THE SUM IS :  {sum_of_two_diagonals}")

The error is:

Traceback (most recent call last):
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\practise.py", line 21, in <module>
    m1 = jitted_function(1, 1)
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\venv\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "E:\Users\SoniTech\PycharmProjects\computer_hardware\venv\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function setitem>) found for signature:
 
 >>> setitem(array(undefined, 1d, C), int64, array(int64, 1d, C))
 
There are 16 candidate implementations:
   - Of which 16 did not match due to:
   Overload of function 'setitem': File: <numerous>: Line N/A.
     With argument(s): '(array(undefined, 1d, C), int64, array(int64, 1d, C))':
    No match.

During: typing of setitem at E:\Users\SoniTech\PycharmProjects\computer_hardware\practise.py (10)

File "practise.py", line 10:
def create_matrix(row, col):
    <source elided>
    """
    arr = np.array([[j + (col * i) for j in range(1, col + 1)] for i in range(row)])  # create a matrix starting 1
    ^

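Some of the functions you use, such as np.matrix(), are not supported in njit (nopython) mode. The traceback itself points at the np.array(...) line: numba fails to type the nested list comprehension passed to np.array(). You can instead rewrite create_matrix() with explicit loops so that numba compiles it using only features it supports.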
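A more compact variant of the same idea, offered as a sketch: numba's supported NumPy subset includes np.arange and the one-argument form of ndarray.reshape, so the matrix can be built without any Python lists. Since the question's stated goal is parallel execution, the diagonal sum also uses njit(parallel=True) with prange, which numba can compile as a parallel reduction. diagonal_sum() is a hypothetical helper, and the ImportError fallback exists only so the sketch still runs (without JIT) where numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback only so the sketch runs without numba: no JIT, no parallelism.
    prange = range

    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]  # used as a bare @njit
        return lambda func: func  # used as @njit(...)


@njit
def create_matrix(row, col):
    # np.arange + one-argument reshape both work in nopython mode;
    # the result is a plain int64 ndarray, not np.matrix.
    return np.arange(1, row * col + 1).reshape((row, col))


@njit(parallel=True)
def diagonal_sum(m):
    # Hypothetical helper: primary-diagonal sum written as a prange loop;
    # numba treats `total += ...` on a scalar as a parallel reduction.
    n = min(m.shape[0], m.shape[1])
    total = 0
    for i in prange(n):
        total += m[i, i]
    return total


m1 = create_matrix(1, 1)
m2 = create_matrix(4, 4)
print(diagonal_sum(m1) + diagonal_sum(m2))  # 35
```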

#imports
from numba import njit
import numpy as np

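# create a 2D matrix of consecutive numbers, starting at 1
def create_matrix(row, col):
    # np.zeros is supported by numba; dtype=np.int64 keeps integer values like the original
    arr = np.zeros((row, col), dtype=np.int64)
    for i in range(row):
        for j in range(1, col + 1):
            arr[i, j - 1] = j + (col * i)
    # return a plain ndarray: np.matrix is not supported in nopython mode
    return arr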

# calculate the sum of primary diagonals of matrix1
jitted_function = njit()(create_matrix)