numba gives error when reshaping numpy array

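I am trying to use Numba to optimize some code that contains loops and matrix operations. However, I am running into some errors. Please find the code and output below.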

Code:

@njit
def list_of_distance(d1): #d1 was declared as List()
    list_of_dis = List()
    for k in range(len(d1)):
        sum_dist = List()
        for j in range(3):
            s = np.sum(square(np.reshape(d1[k][:,:,j].copy(),d1[k][:,:,j].shape[0]*d1[k][:,:,j].shape[1]))) 
            sum_dist.append(s) # square each value in the resulting list (dimension)
        distance = np.sum(sum_dist) # adding the total value for each dimension to a list
        list_of_dis.append(np.round(np.sqrt(distance)))  # Sum the values to get the total squared values of residual images 

    return list_of_dis

Output:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function sum at 0x7f898814bd08>) with argument(s) of type(s): (list(int64))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function sum at 0x7f898814bd08>)
[2] During: typing of call at <ipython-input-18-8c787cc8deda> (7)


File "<ipython-input-18-8c787cc8deda>", line 7:
def list_of_distance(d1):
    <source elided>
        for j in range(3):
            s = np.sum(square(np.reshape(d1[k][:,:,j].copy(),d1[k][:,:,j].shape[0]*d1[k][:,:,j].shape[1]))) 
            ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/latest/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/latest/reference/numpysupported.html

For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile

If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new

Can anyone help me fix this?

Thanks and best regards,

Michael

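I had to make a few changes to get it working and to mock up "d1", but with those changes it does work with Numba for me. The main cause of the error appears to be that np.sum does not accept a list in Numba's nopython mode (the traceback reports the argument type as list(int64)), even though the original code runs correctly when @njit is commented out. Wrapping sum_dist in np.array() fixed the problem.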
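The essential change can be reproduced in isolation. In this minimal sketch (the function sum_of_squares_list is hypothetical), values are collected in a list inside a jitted function, and the list is converted with np.array() before calling np.sum, which is the same fix applied to sum_dist. The try/except fallback to a no-op decorator is an assumption added only so the sketch runs where Numba is not installed; in that case it is plain Python and the typing problem cannot occur.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba, fall back to a no-op decorator (plain Python).
    def njit(func):
        return func


@njit
def sum_of_squares_list(n):
    # Hypothetical helper: collect per-item results in a list, then total them.
    parts = []  # Numba infers list(int64) from the appends below
    for v in range(n):
        parts.append(v * v)
    # np.sum(parts) is rejected in nopython mode with argument type list(int64);
    # converting the list to a 1-D array first gives np.sum a supported input.
    return np.sum(np.array(parts))


print(sum_of_squares_list(4))  # 0 + 1 + 4 + 9 = 14
```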

import numpy as np
from numba import njit

d1 = [np.arange(27).reshape(3,3,3), np.arange(27,54).reshape(3,3,3)]

@njit
def list_of_distance(d1): #d1 was declared as List()
    list_of_dis = [] # changed from List(), which would not compile
    for k in range(len(d1)):
        sum_dist = [] # changed from List(), which would not compile
        for j in range(3):
            s = np.sum(np.square(np.reshape(d1[k][:,:,j].copy(),d1[k][:,:,j].shape[0]*d1[k][:,:,j].shape[1]))) # sum of squares for channel j; added np. to "square"
            sum_dist.append(s) # collect the per-channel sums
        distance = np.sum(np.array(sum_dist)) # total over all channels; list wrapped in np.array so np.sum accepts it
        list_of_dis.append(np.round(np.sqrt(distance))) # square root of the total, rounded, for each residual image

    return list_of_dis

list_of_distance(d1)
Out[11]: [79.0, 212.0]
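Going a step further, the copy and reshape are probably unnecessary: np.sum reduces over all elements of a 2-D array, so each channel slice can be squared and summed directly, and accumulating into a scalar removes the intermediate list (and the np.sum-on-a-list problem) altogether. This sketch assumes d1 is a plain Python list of arrays shaped (height, width, channels) with at least three channels, as in the mock-up above; the name list_of_distance_simplified is hypothetical, and the try/except fallback to a no-op decorator is only there so the sketch also runs without Numba. The last lines cross-check the result against a fully vectorized NumPy expression.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba, fall back to a no-op decorator (plain Python).
    def njit(func):
        return func


@njit
def list_of_distance_simplified(d1):
    # Hypothetical simplified variant of list_of_distance.
    # d1: list of arrays shaped (height, width, channels), channels >= 3.
    list_of_dis = []
    for k in range(len(d1)):
        img = d1[k]
        total = 0.0  # running sum of squares over the first three channels
        for j in range(3):
            # np.sum reduces the whole 2-D slice, so no .copy() or np.reshape is needed
            total += np.sum(np.square(img[:, :, j]))
        list_of_dis.append(np.round(np.sqrt(total)))
    return list_of_dis


d1 = [np.arange(27).reshape(3, 3, 3), np.arange(27, 54).reshape(3, 3, 3)]
result = list_of_distance_simplified(d1)

# Vectorized NumPy cross-check (no loops, no Numba): square, sum, sqrt, round per array.
expected = [np.round(np.sqrt(np.sum(np.square(a[:, :, :3])))) for a in d1]
# Both should give 79.0 and 212.0, matching the output shown above.
```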