为什么 numba 和 numpy 在就地操作上表现不同?
Why numba and numpy performs differently on inplace operations?
用 numba njit
修饰的函数,它是 jit(nopython=True)
的别名,与 numpy
就地操作(简单的 @jit(nopython=False)
也给出与 numpy
不同的结果):
In [1]: import numpy as np
from numba import njit
def npfun(arr):
arr -= arr[3]
@njit
def jitfun(arr):
arr -= arr[3]
arr1 = np.ones((6,2))
arr2 = arr1.copy()
npfun(arr1)
jitfun(arr2)
arr1 == arr2
Out[1]: array([[ True, True],
[ True, True],
[ True, True],
[ True, True],
[False, False],
[False, False]], dtype=bool)
看起来 numpy
评估 rhs 并将其作为副本传递,而 numba
对待 rhs作为一种观点。这样做有什么技术原因吗?
numpy 1.13.3
numba 0.35
解决方法
如果您使用带有 out
参数的显式就地操作而不是扩充赋值,它将对 Numba 有所帮助:替换
arr -= arr[3]
和
np.subtract(arr, arr[3], out=arr)
这使得 jitted 版本(尽管使用 @jit,而不是 @njit)与 NumPy 版本的性能相同。值得注意的是,@njit 这个函数的尝试现在会失败,告诉你 njit 无法处理 out
参数。
我已经看到你了 opened an issue on this,因此有可能更改增强赋值的处理以匹配 NumPy。
为什么会这样
正如 hpaulj 所说,Numba 输出相当于遍历数组的行。 Python JITter 不能对 NumPy 底层的 C 代码做任何事情;它需要 Python 才能使用。支持 NumPy 方法(在某种程度上)的原因是 Numba 开发人员遇到了将 NumPy 数组操作解码为标量对象上的显式 Python 迭代的麻烦,然后可以将其传递给 LLVM。来自 documentation:
- Synthesize a Python function that implements the array expression: This new Python function essentially behaves like a Numpy ufunc, returning the result of the expression on scalar values in the broadcasted array arguments. The lowering function accomplishes this by translating from the array expression tree into a Python AST.
- Compile the synthetic Python function into a kernel: At this point, the lowering function relies on existing code for lowering ufunc and DUFunc kernels, calling
numba.targets.numpyimpl.numpy_ufunc_kernel()
after defining how to lower calls to the synthetic function.
The end result is similar to loop lifting in Numba’s object mode.
前面提到的 numpy_ufunc_kernel
收集索引并 对它们进行迭代 。如果被迭代的对象在迭代过程中发生了变异,那就让事情变得棘手了。
您正在执行的操作:
arr -= arr[3]
曾经是 NumPy 中的未定义行为。它最近才在 NumPy 1.13.0, released June 7th this year 中被定义。新定义的行为总是像复制所有输入一样,尽管它会在检测到不需要时尝试避免实际制作副本。
看起来 Numba 目前没有尝试模仿新行为,要么是因为它太新了,要么是因为 Numba 特定的问题。
用 numba njit
修饰的函数,它是 jit(nopython=True)
的别名,与 numpy
就地操作(简单的 @jit(nopython=False)
也给出与 numpy
不同的结果):
In [1]: import numpy as np
from numba import njit
def npfun(arr):
arr -= arr[3]
@njit
def jitfun(arr):
arr -= arr[3]
arr1 = np.ones((6,2))
arr2 = arr1.copy()
npfun(arr1)
jitfun(arr2)
arr1 == arr2
Out[1]: array([[ True, True],
[ True, True],
[ True, True],
[ True, True],
[False, False],
[False, False]], dtype=bool)
看起来 numpy
评估 rhs 并将其作为副本传递,而 numba
对待 rhs作为一种观点。这样做有什么技术原因吗?
numpy 1.13.3
numba 0.35
解决方法
如果您使用带有 out
参数的显式就地操作而不是扩充赋值,它将对 Numba 有所帮助:替换
arr -= arr[3]
和
np.subtract(arr, arr[3], out=arr)
这使得 jitted 版本(尽管使用 @jit,而不是 @njit)与 NumPy 版本的性能相同。值得注意的是,@njit 这个函数的尝试现在会失败,告诉你 njit 无法处理 out
参数。
我已经看到你了 opened an issue on this,因此有可能更改增强赋值的处理以匹配 NumPy。
为什么会这样
正如 hpaulj 所说,Numba 输出相当于遍历数组的行。 Python JITter 不能对 NumPy 底层的 C 代码做任何事情;它需要 Python 才能使用。支持 NumPy 方法(在某种程度上)的原因是 Numba 开发人员遇到了将 NumPy 数组操作解码为标量对象上的显式 Python 迭代的麻烦,然后可以将其传递给 LLVM。来自 documentation:
- Synthesize a Python function that implements the array expression: This new Python function essentially behaves like a Numpy ufunc, returning the result of the expression on scalar values in the broadcasted array arguments. The lowering function accomplishes this by translating from the array expression tree into a Python AST.
- Compile the synthetic Python function into a kernel: At this point, the lowering function relies on existing code for lowering ufunc and DUFunc kernels, calling
numba.targets.numpyimpl.numpy_ufunc_kernel()
after defining how to lower calls to the synthetic function.The end result is similar to loop lifting in Numba’s object mode.
前面提到的 numpy_ufunc_kernel
收集索引并 对它们进行迭代 。如果被迭代的对象在迭代过程中发生了变异,那就让事情变得棘手了。
您正在执行的操作:
arr -= arr[3]
曾经是 NumPy 中的未定义行为。它最近才在 NumPy 1.13.0, released June 7th this year 中被定义。新定义的行为总是像复制所有输入一样,尽管它会在检测到不需要时尝试避免实际制作副本。
看起来 Numba 目前没有尝试模仿新行为,要么是因为它太新了,要么是因为 Numba 特定的问题。