Behavior of the range function in numba-compiled functions
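I just realized that using the jit decorator together with the range function leads to strange behavior. Rather than a long explanation, consider the following simple code: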
import time

import numba as nb
import numpy as np


# Both range bounds are built from constants (5) and the array sizes.
@nb.njit(['float64[:,:](float64[:,:], float64[:,:], int32, int32)'])
def range1(a, b, nx, nz):
    for ix in range(5, nx-5):
        for iz in range(5, nz-5):
            b[ix, iz] = 0.5*(a[ix+1, iz+1] - a[ix-1, iz-1])
    return b


# All range bounds are passed in as variables.
@nb.njit(['float64[:,:](float64[:,:], float64[:,:], int32, int32, int32, int32)'])
def range2(a, b, ix1, ix2, iz1, iz2):
    for ix in range(ix1, ix2):
        for iz in range(iz1, iz2):
            b[ix, iz] = 0.5*(a[ix+1, iz+1] - a[ix-1, iz-1])
    return b


# Same as range2, except that the inner loop starts from the constant 5.
@nb.njit(['float64[:,:](float64[:,:], float64[:,:], int32, int32, int32, int32)'])
def range3(a, b, ix1, ix2, iz1, iz2):
    for ix in range(ix1, ix2):
        for iz in range(5, iz2):
            b[ix, iz] = 0.5*(a[ix+1, iz+1] - a[ix-1, iz-1])
    return b


if __name__ == "__main__":

    print('Numba : {}'.format(nb.__version__))
    print('Numpy : {}\n'.format(np.__version__))

    nx, nz = 1024, 1024
    a = np.random.rand(nx, nz)
    b = np.zeros_like(a)

    # Call each function once before timing it.
    range1(a, b, nx, nz)
    range2(a, b, 5, nx-5, 5, nz-5)
    range3(a, b, 5, nx-5, 5, nz-5)

    Nit = 1000

    ti = time.time()
    for i in range(Nit):
        range1(a, b, nx, nz)
    print('range1 : {:.3f}'.format(time.time() - ti))

    ti = time.time()
    for i in range(Nit):
        range2(a, b, 5, nx-5, 5, nz-5)
    print('range2 : {:.3f}'.format(time.time() - ti))

    ti = time.time()
    for i in range(Nit):
        range3(a, b, 5, nx-5, 5, nz-5)
    print('range3 : {:.3f}'.format(time.time() - ti))
The three jitted functions, compiled in nopython mode, are almost identical... except for the range arguments. On my laptop, this code returns:
Numba : 0.37.0
Numpy : 1.14.2
range1 : 1.736 s.
range2 : 2.406 s.
range3 : 1.723 s.
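As you can see, there is a large difference between the execution times of range1 and range2! After some testing, I came to the following conclusions:
- When the arguments of the range function are given directly as constants inside the function being compiled, or are variables equal to 0 (as in range1 and range3), performance is quite good.
- On the other hand, when the arguments of the range function are variables, the function runs about 40% slower!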
I think this comes from the way numba compiles the range function. This raises two main questions:
- Why?!
- How can this be fixed?
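The issue here appears to be wraparound indexing semantics. If you pass a negative index to, for example, b[ix, iz], numpy follows Python and indexes from the end of the array axis.
This can be seen in the LLVM IR. The listing below is trimmed because the full output is quite noisy; I found the inner loop of each function by searching for the fmul instruction.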
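The search described above can be sketched in a few lines of Python. This is only an illustration: blocks_containing is a hypothetical helper (not part of numba), its label detection is a rough heuristic, and sample_ir is a made-up fragment rather than real numba output. With a compiled function, the input would be one of the strings returned by inspect_llvm().

```python
def blocks_containing(ir_text, needle):
    """Return the basic blocks of an LLVM IR dump that contain `needle`.

    Hypothetical helper: a block is assumed to start at a line such as
    "B38.us:" (optionally followed by a "; preds = ..." comment).
    """
    blocks, current = [], []
    for line in ir_text.splitlines():
        stripped = line.strip()
        # Drop any trailing comment before testing for a label.
        head = stripped.split(";", 1)[0].strip()
        is_label = head.endswith(":") and " " not in head
        if is_label and current:
            blocks.append(current)
            current = []
        if stripped:
            current.append(stripped)
    if current:
        blocks.append(current)
    return ["\n".join(b) for b in blocks if any(needle in l for l in b)]


# Made-up IR fragment, for illustration only.
sample_ir = """\
entry:
  %x = add i64 %a, 1
  br label %B38.us
B38.us:                                   ; preds = %entry, %B38.us
  %d = fsub double %p, %q
  %m = fmul double %d, 5.000000e-01
  br i1 %c, label %B38.us, label %exit
exit:
  ret i32 0
"""

inner = blocks_containing(sample_ir, "fmul")
print(inner[0])
```

On real output the heuristic may also pick up unrelated blocks, so the result still needs to be read by eye.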
# ir for first overload
ir = next(iter(range1.inspect_llvm().values()))
# range1 inner loop
B38.us: ; preds = %B38.lr.ph.us, %B38.us
%lsr.iv8 = phi i64 [ 0, %B38.lr.ph.us ], [ %lsr.iv.next9, %B38.us ]
%lsr.iv4 = phi i64 [ %lsr.iv2, %B38.lr.ph.us ], [ %lsr.iv.next5, %B38.us ]
%lsr.iv = phi i64 [ %17, %B38.lr.ph.us ], [ %lsr.iv.next, %B38.us ]
%31 = add i64 %lsr.iv10, %lsr.iv8
%.490.us = inttoptr i64 %31 to double*
%.491.us = load double, double* %.490.us, align 8
%32 = add i64 %lsr.iv6, %lsr.iv8
%.576.us = inttoptr i64 %32 to double*
%.577.us = load double, double* %.576.us, align 8
%.585.us = fsub double %.491.us, %.577.us
%.595.us = fmul double %.585.us, 5.000000e-01
%.659.us = inttoptr i64 %lsr.iv4 to double*
store double %.595.us, double* %.659.us, align 8
%lsr.iv.next = add nsw i64 %lsr.iv, -1
%lsr.iv.next5 = add i64 %lsr.iv4, %arg.b.6.1
%lsr.iv.next9 = add i64 %lsr.iv8, %arg.a.6.1
%.338.us = icmp sgt i64 %lsr.iv.next, 1
br i1 %.338.us, label %B38.us, label %B94.us
# range2 inner loop
B30.us: ; preds = %B30.lr.ph.us, %B30.us
%lsr.iv = phi i32 [ %1, %B30.lr.ph.us ], [ %lsr.iv.next, %B30.us ]
%.253.025.us = phi i32 [ %arg.iz1, %B30.lr.ph.us ], [ %.323.us, %B30.us ]
%.323.us = add i32 %.253.025.us, 1
%.400.us = sext i32 %.253.025.us to i64
%.401.us = add nsw i64 %.400.us, 1
%.441.us = icmp slt i32 %.253.025.us, -1
%.442.us = select i1 %.441.us, i64 %arg.a.5.1, i64 0
%.443.us = add i64 %.401.us, %.442.us
%.460.us = mul i64 %.443.us, %arg.a.6.1
%.463.us = add i64 %.461.us, %.460.us
%.464.us = inttoptr i64 %.463.us to double*
%.465.us = load double, double* %.464.us, align 8
%.489.us = add nsw i64 %.400.us, -1
%.529.us = icmp slt i32 %.253.025.us, 1
%.530.us = select i1 %.529.us, i64 %arg.a.5.1, i64 0
%.531.us = add i64 %.489.us, %.530.us
%.548.us = mul i64 %.531.us, %arg.a.6.1
%.551.us = add i64 %.549.us, %.548.us
%.552.us = inttoptr i64 %.551.us to double*
%.553.us = load double, double* %.552.us, align 8
%.561.us = fsub double %.465.us, %.553.us
%.571.us = fmul double %.561.us, 5.000000e-01
%.618.us = icmp slt i32 %.253.025.us, 0
%.619.us = select i1 %.618.us, i64 %arg.b.5.1, i64 0
%.620.us = add i64 %.619.us, %.400.us
%.637.us = mul i64 %.620.us, %arg.b.6.1
%.640.us = add i64 %.638.us, %.637.us
%.641.us = inttoptr i64 %.640.us to double*
store double %.571.us, double* %.641.us, align 8
%lsr.iv.next = add i32 %lsr.iv, -1
%.310.us = icmp sgt i32 %lsr.iv.next, 1
br i1 %.310.us, label %B30.us, label %B86.us
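Even there, there is a lot to parse, but in range1 only pointer bumping, loads and stores, and arithmetic happen, whereas range2 checks every index for a negative value (the icmp instructions, each followed by a select), because the compiler cannot prove that iz will never be negative.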
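As a rough illustration of what each icmp / select / add triple in the range2 loop computes, here is a pure-Python sketch. The name wrap_index is hypothetical, and reading the select operand (for example %arg.b.5.1) as the length of the indexed axis is an assumption based on how the value is used.

```python
def wrap_index(i, n):
    """Mirror the icmp slt / select / add sequence seen in the range2 IR."""
    is_negative = i < 0                # icmp slt i, 0
    offset = n if is_negative else 0   # select is_negative, n, 0
    return i + offset                  # add i, offset


row = [10.0, 20.0, 30.0, 40.0]
n = len(row)

# The wrapped index picks the same element that Python's own indexing does.
for i in (-1, 0, 3):
    assert row[wrap_index(i, n)] == row[i]

print(wrap_index(-1, n), wrap_index(2, n))  # prints: 3 2
```

In range2 this adjustment runs for both loads and for the store on every iteration, and each adjusted index is then multiplied by a stride. In range1 the loop starts from a constant, so none of this is needed and the addresses are simply incremented.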
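As far as I know, other than starting the range from a compile-time constant as you did, there is currently no way to omit these checks. There used to be a wraparound flag to enable or disable this behavior, but it was removed.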