Numba 函数需要很长时间才能为数组赋值

Numba function take long time for assign value to an array

我写了一个函数来计算 Numba 图像的 HOG,我 运行 它在 7000 张图像上。需要 10 秒的时间。但是当我评论将变量分配给数组 ( hist[idx] += mag ) 的行时,时间减少到 5 毫秒。这是什么问题,我应该怎么做。

@numba.jit( numba.uint64[:]( numba.uint8[:,:],numba.uint8), nopython=True )
def hog_numba( img, bins ):
    h,w = img.shape
    hist = np.zeros( bins, dtype=np.uint64)
    for i in range(h-1):
        for j in range(w-1):
            cy = img[i-1,j-1]*1 + img[i-1,j]*2 + img[i-1,j+1]*1 + img[i+1,j-1]*-1 + img[i+1,j]*-2 + img[i+1,j+1]*-1
            cx = img[i-1,j-1]*1 + img[i,j-1]*2 + img[i+1,j-1]*1 + img[i-1,j+1]*-1 + img[i,j+1]*-2 + img[i+1,j+1]*-1

            mag  =  numba.uint32(math.sqrt( math.pow(cx,2) + math.pow(cy,2) ) )

            if cx!=0:
                ang = math.atan2( cy, cx)#arc_tang
            else :
                if cy>0:
                    ang = math.pi / 2
                else:
                    ang = -math.pi / 2
                
            if ang<0:
                ang = abs(ang) + math.pi
            
            idx = (ang * bins) // (math.pi * 2 )
            idx = int(idx)

            #hist[idx] += mag

    
    return hist

以下用于基准测试的代码

for _ in range(20):
    print('start')
    t = time.time()
    hists = []
    for i in range(8000):
        hist = hog_numba(img, 10)
    t = time.time() - t
    print('time:',t)

速度上的差异不是因为赋值慢,而是因为JIT 编译器的优化。事实上,如果您注释行 hist[idx] += mag,那么 Numba 可以看到 magidx 不需要计算,并且可以删除关联的行。传递性地,它也可以去除angcxcy的计算。最后它可以完全删除两个嵌套循环。这样的代码会更快但也没有用。然而,JIT 在实践中可能不会完全删除两个嵌套循环内的所有操作,因为 JIT 可能由于 Python 转换、保护和副作用而无法完全优化代码。在我的机器上确实将循环优化为空操作。事实上,计算 8000 张大小为 (16_000,16_000) 的图像平均需要不到 1 毫秒,这在我的机器上是完全不可能的(它应该至少慢 1000 倍)。

因此,您无法通过仅删除独立指令来测量它的时间,并使用 Numba(或任何优化的编译代码)查找时间差异。现代编译器非常先进,想要打败它们并不容易。如果您仍想查看成本是否实际上主要来自赋值,您可以尝试执行求和,如 mag_sum += magidx_sum += idx 和 return/print 求和变量(否则编译器可以看到它们是无用的,因为它们不会引起可见的变化)。在我的机器上,赋值版本仅比使用求和的实现慢 9%,显示赋值不会占用大部分执行时间(尽管速度不是很快,可能是由于随机访问模式)。

减速的主要来源来自 (ang * bins) // (math.pi * 2 ) 行,更具体地说,来自 multiplication/division 常量 。提前在临时变量中预先计算 bins / (math.pi * 2) 会导致 3.5 倍快的代码 。代码远未优化。进一步的优化包括使用矢量化、无分支操作和并行性(使用简单精度并尝试删除 math.atan2 调用也可能有所帮助)。