在 numba.jit 装饰器中使用并行选项会使函数给出错误的结果
Usage of parallel option in numba.jit decoratior makes function give wrong result
给定矩形的两个对角 (x1, y1)
和 (x2, y2)
以及两个半径 r1
和 r2
,求位于由半径 r1
和 r2
到矩形中点的总数。
简单的 NumPy 方法:
def func_1(x1,y1,x2,y2,r1,r2,n):
x11,y11 = np.meshgrid(np.linspace(x1,x2,n),np.linspace(y1,y2,n))
z1 = np.sqrt(x11**2+y11**2)
a = np.where((z1>(r1)) & (z1<(r2)))
fill_factor = len(a[0])/(n*n)
return fill_factor
接下来我尝试使用 numba 的 jit
装饰器优化这个函数。当我使用:
nopython = True
函数速度更快,输出正确。但是当我还添加:
parallel = True
该函数速度更快但给出了错误的结果。
我知道这与我的 z
矩阵有关,因为它没有正确更新。
@jit(nopython=True,parallel=True)
def func_2(x1,y1,x2,y2,r1,r2,n):
x_ = np.linspace(x1,x2,n)
y_ = np.linspace(y1,y2,n)
z1 = np.zeros((n,n))
for i in range(n):
for j in range(n):
z1[i][j] = np.sqrt((x_[i]*x_[i]+y_[j]*y_[j]))
a = np.where((z1>(r1)) & (z1<(r2)))
fill_factor = len(a[0])/(n*n)
return fill_factor
测试值:
x1 = 1.0
x2 = -1.0
y1 = 1.0
y2 = -1.0
r1 = 0.5
r2 = 0.75
n = 25000
其他信息:Python 版本:3.6.1,Numba 版本:0.34.0+5.g1762237,NumPy 版本:1.13.1
parallel=True
的问题在于它是一个黑匣子。 Numba 甚至不保证它实际上会并行化任何东西。它使用启发式方法来确定它是否可并行化以及可以 并行完成的内容。这些可能会失败,在您的示例中它们确实会失败,就像 my experiments with parallel
and numba 中一样。这使得 parallel
不可信,我建议 反对 使用它!
在较新的版本 (0.34) 中添加了 prange
,您可能会更幸运。它不能在这种情况下应用,因为 prange
的工作方式类似于 range
而不同于 np.linspace
...
请注意:您可以完全避免构建 z
并在您的函数中执行 np.where
,您可以明确地进行检查:
import numpy as np
import numba as nb
@nb.njit # equivalent to "jit(nopython=True)".
def func_2(x1,y1,x2,y2,r1,r2,n):
x_ = np.linspace(x1,x2,n)
y_ = np.linspace(y1,y2,n)
cnts = 0
for i in range(n):
for j in range(n):
z = np.sqrt(x_[i] * x_[i] + y_[j] * y_[j])
if r1 < z < r2:
cnts += 1
fill_factor = cnts/(n*n)
return fill_factor
与您的函数相比,这也应该提供一些加速,甚至可能比使用 parallel=True
更快(如果它能正常工作)。
给定矩形的两个对角 (x1, y1)
和 (x2, y2)
以及两个半径 r1
和 r2
,求位于由半径 r1
和 r2
到矩形中点的总数。
简单的 NumPy 方法:
def func_1(x1,y1,x2,y2,r1,r2,n):
x11,y11 = np.meshgrid(np.linspace(x1,x2,n),np.linspace(y1,y2,n))
z1 = np.sqrt(x11**2+y11**2)
a = np.where((z1>(r1)) & (z1<(r2)))
fill_factor = len(a[0])/(n*n)
return fill_factor
接下来我尝试使用 numba 的 jit
装饰器优化这个函数。当我使用:
nopython = True
函数速度更快,输出正确。但是当我还添加:
parallel = True
该函数速度更快但给出了错误的结果。
我知道这与我的 z
矩阵有关,因为它没有正确更新。
@jit(nopython=True,parallel=True)
def func_2(x1,y1,x2,y2,r1,r2,n):
x_ = np.linspace(x1,x2,n)
y_ = np.linspace(y1,y2,n)
z1 = np.zeros((n,n))
for i in range(n):
for j in range(n):
z1[i][j] = np.sqrt((x_[i]*x_[i]+y_[j]*y_[j]))
a = np.where((z1>(r1)) & (z1<(r2)))
fill_factor = len(a[0])/(n*n)
return fill_factor
测试值:
x1 = 1.0
x2 = -1.0
y1 = 1.0
y2 = -1.0
r1 = 0.5
r2 = 0.75
n = 25000
其他信息:Python 版本:3.6.1,Numba 版本:0.34.0+5.g1762237,NumPy 版本:1.13.1
parallel=True
的问题在于它是一个黑匣子。 Numba 甚至不保证它实际上会并行化任何东西。它使用启发式方法来确定它是否可并行化以及可以 并行完成的内容。这些可能会失败,在您的示例中它们确实会失败,就像 my experiments with parallel
and numba 中一样。这使得 parallel
不可信,我建议 反对 使用它!
在较新的版本 (0.34) 中添加了 prange
,您可能会更幸运。它不能在这种情况下应用,因为 prange
的工作方式类似于 range
而不同于 np.linspace
...
请注意:您可以完全避免构建 z
并在您的函数中执行 np.where
,您可以明确地进行检查:
import numpy as np
import numba as nb
@nb.njit # equivalent to "jit(nopython=True)".
def func_2(x1,y1,x2,y2,r1,r2,n):
x_ = np.linspace(x1,x2,n)
y_ = np.linspace(y1,y2,n)
cnts = 0
for i in range(n):
for j in range(n):
z = np.sqrt(x_[i] * x_[i] + y_[j] * y_[j])
if r1 < z < r2:
cnts += 1
fill_factor = cnts/(n*n)
return fill_factor
与您的函数相比,这也应该提供一些加速,甚至可能比使用 parallel=True
更快(如果它能正常工作)。