理解 numba 并行化中的这种竞争条件
understanding this race condition in numba parallelization
Numba 文档中有一个关于并行竞争条件的示例
import numba as nb
import numpy as np
@nb.njit(parallel=True)
def prange_wrong_result(x):
n = x.shape[0]
y = np.zeros(4)
for i in nb.prange(n):
y[:]+= x[i]
return y
我有运行,确实输出异常
prange_wrong_result(np.ones(10000))
#array([5264., 5273., 5231., 5234.])
然后我尝试将循环更改为
import numba as nb
import numpy as np
@nb.njit(parallel=True)
def prange_wrong_result(x):
n = x.shape[0]
y = np.zeros(4)
for i in nb.prange(n):
y+= x[i]
return y
并输出
prange_wrong_result(np.ones(10000))
#array([10000., 10000., 10000., 10000.])
我已经阅读了一些竞争条件解释。但是还是不懂
- 为什么第二个例子没有赛车条件?
y[:]=
和 y=
有什么区别
- 为什么第一个例子中四个元素的输出不一样?
在您的第一个示例中,您有多个 threads/processes 共享同一个数组并读取并分配给共享数组。 y[:] += x[i]
大致相当于:
y[0] += x[i]
y[1] += x[i]
y[2] += x[i]
y[3] += x[i]
实际上 +=
只是读取、加法和赋值操作的语法糖,所以 y[0] += x[i]
实际上是:
_value = y[0]
_value = _value + x[i]
y[0] = _value
循环体由多个 threads/processes 同时执行,这就是竞争条件的来源。维基百科上关于竞争条件的示例适用于此处:
这就是返回的数组包含错误值以及每个元素可能不同的原因。因为 thread/process 运行的时间是不确定的。所以在某些情况下,一个元素上存在竞争条件,有时在 none 上,有时在多个元素上。
然而,numba 开发人员已经实施了一些支持的减少,其中没有出现竞争条件。其中之一是 y +=
。这里重要的是它是变量本身,而不是变量的 slice/element 。在那种情况下,numba 做了一些非常聪明的事情。他们为每个 thread/process 复制变量的初始值,然后对该副本进行操作。并行循环完成后,他们将复制的值相加。以你的第二个例子为例,假设它使用 2 个进程,它看起来大致像这样:
y = np.zeros(4)
y_1 = y.copy()
y_2 = y.copy()
for i in nb.prange(n):
if is_process_1:
y_1[:] += x[i]
if is_process_2:
y_2[:] += x[i]
y += y_1
y += y_2
由于每个线程都有自己的数组,因此不会出现竞争条件。为了让 numba 能够推断出这一点,你必须遵守他们的限制。文档指出 numba 在标量和数组 (y += x[i]
) 上为 +=
创建无竞争条件的并行代码,但 不在数组 elements/slices 上(y[:] += x[i]
或 y[1] += x[i]
)。
Numba 文档中有一个关于并行竞争条件的示例
import numba as nb
import numpy as np
@nb.njit(parallel=True)
def prange_wrong_result(x):
n = x.shape[0]
y = np.zeros(4)
for i in nb.prange(n):
y[:]+= x[i]
return y
我有运行,确实输出异常
prange_wrong_result(np.ones(10000))
#array([5264., 5273., 5231., 5234.])
然后我尝试将循环更改为
import numba as nb
import numpy as np
@nb.njit(parallel=True)
def prange_wrong_result(x):
n = x.shape[0]
y = np.zeros(4)
for i in nb.prange(n):
y+= x[i]
return y
并输出
prange_wrong_result(np.ones(10000))
#array([10000., 10000., 10000., 10000.])
我已经阅读了一些竞争条件解释。但是还是不懂
- 为什么第二个例子没有赛车条件?
y[:]=
和y=
有什么区别
- 为什么第一个例子中四个元素的输出不一样?
在您的第一个示例中,您有多个 threads/processes 共享同一个数组并读取并分配给共享数组。 y[:] += x[i]
大致相当于:
y[0] += x[i]
y[1] += x[i]
y[2] += x[i]
y[3] += x[i]
实际上 +=
只是读取、加法和赋值操作的语法糖,所以 y[0] += x[i]
实际上是:
_value = y[0]
_value = _value + x[i]
y[0] = _value
循环体由多个 threads/processes 同时执行,这就是竞争条件的来源。维基百科上关于竞争条件的示例适用于此处:
这就是返回的数组包含错误值以及每个元素可能不同的原因。因为 thread/process 运行的时间是不确定的。所以在某些情况下,一个元素上存在竞争条件,有时在 none 上,有时在多个元素上。
然而,numba 开发人员已经实施了一些支持的减少,其中没有出现竞争条件。其中之一是 y +=
。这里重要的是它是变量本身,而不是变量的 slice/element 。在那种情况下,numba 做了一些非常聪明的事情。他们为每个 thread/process 复制变量的初始值,然后对该副本进行操作。并行循环完成后,他们将复制的值相加。以你的第二个例子为例,假设它使用 2 个进程,它看起来大致像这样:
y = np.zeros(4)
y_1 = y.copy()
y_2 = y.copy()
for i in nb.prange(n):
if is_process_1:
y_1[:] += x[i]
if is_process_2:
y_2[:] += x[i]
y += y_1
y += y_2
由于每个线程都有自己的数组,因此不会出现竞争条件。为了让 numba 能够推断出这一点,你必须遵守他们的限制。文档指出 numba 在标量和数组 (y += x[i]
) 上为 +=
创建无竞争条件的并行代码,但 不在数组 elements/slices 上(y[:] += x[i]
或 y[1] += x[i]
)。