Better pattern for range iteration in numba?

Looking at how numba unpacks range/xrange, it is clear that a lot is going on, and that it is not simply equivalent to a for loop. For example:

import unittest

import numpy as np
from numba import jit


class FastOpsTest(unittest.TestCase):

    def test_numba_sum(self):
        arr = np.array([1, 2, 3], dtype=float)
        self.assertEqual(FastOpsTest.sum(arr), 6)
        # Print the typed Numba IR for every compiled signature.
        FastOpsTest.sum.inspect_types()

    @staticmethod
    @jit(nopython=True, nogil=True)
    def sum(arr):
        N = len(arr)
        result = 0
        for i in xrange(N):  # Python 2; range() behaves similarly
            result += arr[i]
        return result

...which gives the virtual machine code shown below.

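The for loop appears to be calling into the Python range and xrange C functions. Worse, it looks like it is allocating memory on the heap (I assume that is what the dels are doing). This seems suboptimal, especially for nested loops.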

Other than manually refactoring to:

i = 0
while i != N:
    ...
    i += 1

is there a better pattern for optimizing loops in numba?
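For reference, the two loop forms can be written side by side as a minimal sketch. The names `sum_range` and `sum_while` are hypothetical, `range` is used in place of `xrange` so that it runs on Python 3, and the fallback decorator is only an assumption made so that the snippet still runs (uncompiled) where numba is not installed. It only checks that both forms agree and says nothing about which one is faster.

```python
try:
    import numpy as np
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (uncompiled) where numba is missing.
    np = None

    def njit(func):
        return func


@njit
def sum_range(arr):
    """Sum using the idiomatic range() loop."""
    result = 0.0
    for i in range(len(arr)):
        result += arr[i]
    return result


@njit
def sum_while(arr):
    """Sum using the hand-written while loop from the question."""
    n = len(arr)
    result = 0.0
    i = 0
    while i != n:
        result += arr[i]
        i += 1
    return result


values = [1.0, 2.0, 3.0]
# Use a float64 array when numpy is available, a plain list otherwise.
arr = np.array(values, dtype=np.float64) if np is not None else values
print(sum_range(arr), sum_while(arr))
```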

# File: /home/userx/py/fast_ops.py
# --- LINE 176 --- 
# label 0
#   del $0.1
#   del $0.3
#   del $const0.4

@staticmethod

# --- LINE 177 --- 

@jit(nopython=True, nogil=True)

# --- LINE 178 --- 

def sum(arr):

    # --- LINE 179 --- 
    #   arr = arg(0, name=arr)  :: array(float64, 1d, C)
    #   $0.1 = global(len: <built-in function len>)  :: Function(<built-in function len>)
    #   $0.3 = call $0.1(arr, kws=[], args=[Var(arr, /home/userx/py/fast_ops.py (179))], func=$0.1, vararg=None)  :: (array(float64, 1d, C),) -> int64
    #   N = $0.3  :: int64

    N = len(arr)

    # --- LINE 180 --- 
    #   $const0.4 = const(int, 0)  :: int64
    #   result = $const0.4  :: float64
    #   jump 18
    # label 18

    result = 0

    # --- LINE 181 --- 
    #   jump 21
    # label 21
    #   .1 = global(xrange: <type 'xrange'>)  :: Function(<built-in function range>)
    #   .3 = call .1(N, kws=[], args=[Var(N, /home/userx/py/fast_ops.py (179))], func=.1, vararg=None)  :: (int64,) -> range_state_int64
    #   del N
    #   del .1
    #   .4 = getiter(value=.3)  :: range_iter_int64
    #   del .3
    #   $phi31.1 = .4  :: range_iter_int64
    #   del .4
    #   jump 31
    # label 31
    #   .2 = iternext(value=$phi31.1)  :: pair<int64, bool>
    #   .3 = pair_first(value=.2)  :: int64
    #   .4 = pair_second(value=.2)  :: bool
    #   del .2
    #   $phi54.1 = .3  :: int64
    #   del $phi54.1
    #   $phi54.2 = $phi31.1  :: range_iter_int64
    #   del $phi54.2
    #   $phi34.1 = .3  :: int64
    #   del .3
    #   branch .4, 34, 54
    # label 34
    #   del .4
    #   i = $phi34.1  :: int64
    #   del $phi34.1
    #   del i
    #   del .5
    #   del .6

    for i in xrange(N):

        # --- LINE 182 --- 
        #   .5 = getitem(index=i, value=arr)  :: float64
        #   .6 = inplace_binop(static_rhs=<object object at 0x7fb921f7cbc0>, rhs=.5, immutable_fn=+, lhs=result, static_lhs=<object object at 0x7fb921f7cbc0>, fn=+=)  :: float64
        #   result = .6  :: float64
        #   jump 31
        # label 54
        #   del arr
        #   del $phi34.1
        #   del $phi31.1
        #   del .4
        #   jump 55
        # label 55
        #   del result

        result += arr[i]

    # --- LINE 183 --- 
    #   .2 = cast(value=result)  :: float64
    #   return .2

    return result

inspect_types() returns Numba IR. I am not familiar with it, but I don't think there is any reason to expect it to map closely to what is actually executed.

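Working down the levels of abstraction, you can also view the LLVM IR with the inspect_llvm() method, and see what is actually executed with inspect_asm(). In this case, the LLVM IR makes it pretty clear that this compiles to a very simple for loop; I believe the label B24: is the inner loop.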

print(next(iter(FastOpsTest.sum.inspect_llvm().values())))

# some parts omitted
define i32 @"_ZN8__main__11FastOpsTest7sum2E5ArrayIdLi1E1C7mutable7alignedE"(double* noalias nocapture %retptr, { i8*, i32 }** noalias nocapture readnone %excinfo, i8* noalias nocapture readnone %env, i8* nocapture readnone %arg.arr.0, i8* nocapture readnone %arg.arr.1, i64 %arg.arr.2, i64 %arg.arr.3, double* nocapture readonly %arg.arr.4, i64 %arg.arr.5.0, i64 %arg.arr.6.0) local_unnamed_addr #0 {
entry:
  %.98 = icmp sgt i64 %arg.arr.5.0, 0
  br i1 %.98, label %B24.preheader, label %B40

B24.preheader:                                    ; preds = %entry
  %0 = add i64 %arg.arr.5.0, 1
  br label %B24

B24:                                              ; preds = %B24.preheader, %B24
  %lsr.iv8 = phi double* [ %arg.arr.4, %B24.preheader ], [ %scevgep, %B24 ]
  %lsr.iv = phi i64 [ %0, %B24.preheader ], [ %lsr.iv.next, %B24 ]
  %result.07 = phi double [ %.250, %B24 ], [ 0.000000e+00, %B24.preheader ]
  %.242 = load double, double* %lsr.iv8, align 8
  %.250 = fadd double %result.07, %.242
  %lsr.iv.next = add i64 %lsr.iv, -1
  %scevgep = getelementptr double, double* %lsr.iv8, i64 1
  %.143 = icmp sgt i64 %lsr.iv.next, 1
  br i1 %.143, label %B24, label %B40

B40:                                              ; preds = %B24, %entry
  %result.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %.250, %B24 ]
  store double %result.0.lcssa, double* %retptr, align 8
  ret i32 0
}
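To repeat this kind of check on another function, the generated code can be collected per compiled signature. The following is a rough sketch, assuming a numba version where inspect_llvm() and inspect_asm() return dictionaries keyed by signature when called without arguments. The names `ir_line_counts` and `total` are hypothetical, and the line counts are only a crude indicator of how much code was generated, not a performance measurement.

```python
try:
    import numpy as np
    from numba import njit
except ImportError:
    # numba is not installed, so there is nothing to inspect.
    njit = None


def ir_line_counts(dispatcher):
    """Map each compiled signature to (LLVM IR lines, assembly lines)."""
    llvm_by_sig = dispatcher.inspect_llvm()
    asm_by_sig = dispatcher.inspect_asm()
    return {
        sig: (len(llvm_by_sig[sig].splitlines()),
              len(asm_by_sig[sig].splitlines()))
        for sig in llvm_by_sig
    }


counts = {}
if njit is not None:
    @njit
    def total(arr):
        result = 0.0
        for i in range(len(arr)):
            result += arr[i]
        return result

    # Call once so that a float64 array signature gets compiled.
    total(np.array([1.0, 2.0, 3.0]))
    counts = ir_line_counts(total)
    for sig, (n_llvm, n_asm) in counts.items():
        print(sig, n_llvm, "LLVM IR lines,", n_asm, "assembly lines")
```

Printing the full text for a signature instead of its length shows the same kind of listing as above.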