Why is my Numba JIT function recognized as an array?

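I'm trying to speed up my code with Numba. It works for most functions, but some of them fail. I specified the function's signature, and it doesn't work. If I just write nb.njit, it works, but there is no speedup at all, and even a slight slowdown. When I specify the signature as in the code posted below, I get the following error: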

TypeError: The decorated object is not a function (got type <class 'numba.core.types.npytypes.Array'>).


@nb.njit(nb.int32[:],nb.int32[:],nb.int32[:],nb.int32[:,:](nb.int8,nb.int32[:,:],nb.int32[:,:],nb.int32[:,:],nb.int64,nb.float64[:,:]))
def IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l):
    #Find vertices undergoing T1
    if T1_edge == len(l) - 1:
        T1_verts = cell_verts[i,[T1_edge,0]]
    else:
        T1_verts = cell_verts[i,[T1_edge,T1_edge+1]]
    
    #Identify the four cells that are affected by the transition
    dummy = np.concatenate((vert_cells[T1_verts[0]],vert_cells[T1_verts[1]]))
    T1cells = np.unique(dummy)

    #Identify cells that are neighbours prior to transition (that won't be afterwards)
    old_neigh = np.intersect1d(vert_cells[T1_verts[0]],vert_cells[T1_verts[1]])
    
    #Identify cells that will be neighbours after transition
    notneigh1 = T1cells[T1cells != old_neigh[0]]
    notneigh2 = T1cells[T1cells != old_neigh[1]]
    new_neigh = np.intersect1d(notneigh1,notneigh2)
    old_vert_neighs = vert_neighs[T1_verts,:]
    return T1_verts, old_neigh, new_neigh, old_vert_neighs

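I checked the sizes and data types of the input arrays and numbers and made sure I hadn't made a mistake there. I should add that, for the int8 argument, I had to use j = np.asarray([i],dtype='int8')[0] to convert the int to int8, because I couldn't find a Numba type for int but I did find one for int8. The input number i in my code corresponds to that j and really is of type int8. When I call inspect.isfunction on the undecorated function, it is recognized as a function.

The code that calls the function above is as follows: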

def UpdateTopology(points,verts,vert_neighs,vert_cells,cell_verts,l,x_max,y_max,T1_thresh,N):
    for i in range(N):
        #Determine how many vertices are in cell i
        vert_inds = cell_verts[i,:] < 2*N
        j = np.array([i]).astype('int8')[0]
        #If cell i only has three sides don't perform a T1 on it (you can't have a cell with 2 sides) 
        if(len(vert_inds) == 3):
            continue 
        
        #Find UP TO DATE vertex coords of cell i (can't use cell_vert_coords as
        #vertices will have changed due to previous T1 transitions)
        vert_inds = cell_verts[i,:] < 2*N
        cell_i = verts[cell_verts[i,vert_inds],:]
        
        #Shift vertex coords to account for periodicity
        rel_dists = cell_i - points[i,:]
        cell_i = ShiftCoords(cell_i,rel_dists,x_max,y_max,4)
        
        #Calculate the lengths, l, of each cell edge
        shifted_verts = np.roll(cell_i,-1,axis=0)
        l =  shifted_verts - cell_i
        l_mag = np.linalg.norm(l,axis=1)
        #Determine if any edges are below threshold length
        to_change = np.nonzero(l_mag < T1_thresh)
        #print('l = ',l)
        #print('T1_thresh = ', T1_thresh)
        if len(to_change[0]) == 0:
            continue
        else:
            T1_edge = to_change[0][0]
            #Identify vertices and cells affected, also return vert_neighs of old neighbours (to be used when updating vert_neighs later on)
            T1_verts, old_neigh, new_neigh, old_vert_neighs = T1f.IdentifyVertsAndCells(i,cell_verts,vert_cells,vert_neighs,T1_edge,l)
            #Update vertex coordinates
            verts = T1f.UpdateVertCoords(j,verts,points,cell_i,old_neigh,T1_verts,T1_thresh,l,T1_edge,x_max,y_max)    
            #Update vert_cells
            vert_cells = T1f.UpdateVertCells(verts,points,vert_cells,T1_verts,old_neigh,new_neigh)
            #Update cell_verts 
            cell_verts = T1f.UpdateCellVerts(verts,points,cell_verts,T1_verts,old_neigh,new_neigh,N)
            #Update vert_neighs
            vert_neighs = T1f.UpdateVertNeighs(vert_neighs,points,cell_verts,T1_verts,old_neigh,new_neigh,old_vert_neighs,N)

    return verts, vert_neighs, vert_cells, cell_verts

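The error you are getting comes from the wrongly defined signature you passed to the njit decorator. Whenever your jitted function returns several values, you have to declare the return type as a homogeneous tuple (if all returned values have the same type) or as a heterogeneous tuple (if their types differ) (see this answer).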

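As for the speedup, you won't get any with this code sample; on the contrary, it will slow down. I can identify two main reasons:

  1. You are only using standard NumPy functions, which are already highly optimized. As a rule of thumb, Numba works well for speeding up loops. If your code is already vectorized, you probably won't gain any speed by decorating your functions with njit.
  2. You use a lot of fancy indexing, e.g. cell_verts[i,[T1_edge,T1_edge+1]], which produces array copies that require memory allocation, a task Numba is not very good at (see this answer).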