Numba:索引向量会出错

Numba: indexing a vector is giving an error

我最近开始使用 python 和 numba。我的问题是:我有一个矩阵(n行m列)。在for循环中我必须改变特定列的值。

没有 numba,代码 运行 没问题。但是当我使用 njit() 时,它就崩溃了。

注意:在我的实际项目中,每一行的值都不相同。

这是我需要的示例。

import numpy as np


def func_test(matrix_size, columns_idx, replace_values):
    matrix = np.zeros((matrix_size, matrix_size), dtype='int32')

    for i in range(0, matrix_size):
        matrix[i, columns_idx] = replace_values
    return matrix


columns_idx = [0,2,4]
replace_values = [1, 3, 5]

new_matrix = func_test(5, columns_idx, replace_values)
print(new_matrix)

[[1 0 3 0 5]
 [1 0 3 0 5]
 [1 0 3 0 5]
 [1 0 3 0 5]
 [1 0 3 0 5]]

现在,我将应用 Numba。

import numpy as np
from numba import njit


@njit()
def func_test(matrix_size, columns_idx, replace_values):
    matrix = np.zeros((matrix_size, matrix_size), dtype='int32')

    for i in range(0, matrix_size):
        matrix[i, columns_idx] = replace_values
    return matrix


columns_idx = [0,2,4]
replace_values = [1, 3, 5]

new_matrix = func_test(5, columns_idx, replace_values)
print(new_matrix)

并出现此错误:

No implementation of function Function(<built-in function setitem>) found for signature:

\>>> setitem(array(int32, 2d, C), Tuple(int64, reflected list(int64)<iv=None>), reflected list(int64)<iv=None>)

There are 16 candidate implementations:

  • Of which 14 did not match due to: Overload of function 'setitem': File: <numerous>: Line N/A. With argument(s): '(array(int32, 2d, C), Tuple(int64, reflected list(int64)<iv=None>), reflected list(int64)<iv=None>)': No match.
  • Of which 2 did not match due to: Overload in function 'SetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 176. With argument(s): '(array(int32, 2d, C), Tuple(int64, reflected list(int64)<iv=None>), reflected list(int64)<iv=None>)': Rejected as the implementation raised a specific error: NumbaTypeError: unsupported array index type reflected list(int64)<iv=None> in Tuple(int64, reflected list(int64)<iv=None>) raised from /Users/goncaloguedes/.conda/envs/Python_Fold_UnFold/lib/python3.9/site-packages/numba/core/typing/arraydecl.py:72

During: typing of setitem at /Users/goncaloguedes/Desktop/Fold_Unfold/Python_Fold_UnFold/Testes.py (10)

File "Testes.py", line 10: def func_test(matrix_size, columns_idx, replace_values): <source elided> for i in range(0, matrix_size): matrix[i, columns_idx] = replace_values

Numba 不能很好地支持 反射列表(CPython 列表可能包含不同类型的对象)。您需要使用 类型列表 Numpy 数组 。原因是 Numba 需要处理 well-defined 类型的变量而不是 dynamically-typed 对象以便高效(并且实际上能够在本机二进制文件中编译代码)。在您的情况下,最好使用 Numpy 数组。

此外,Numba 似乎无法编译像 matrix[i, columns_idx] = replace_values 这样的行。 Numba 尚未支持所有 Numpy 功能。因此,您需要改用基本循环。

这是更正代码的示例:

import numpy as np
from numba import njit


@njit()
def func_test(matrix_size, columns_idx, replace_values):
    matrix = np.zeros((matrix_size, matrix_size), dtype=np.int32)

    for i in range(0, matrix_size):
        for j in range(0, columns_idx.size):
            matrix[i, columns_idx[j]] = replace_values[j]
    return matrix


columns_idx = np.array([0, 2, 4], np.int64)
replace_values = np.array([1, 3, 5], np.int32)

new_matrix = func_test(5, columns_idx, replace_values)
print(new_matrix)