Numba 通过就地影响破坏数据
Numba corrupts data by affecting in-place
Numba 和 NumPy 不会以相同的方式执行以下 foo
函数:
from numba import jit
import numpy as np
@jit
def foo(a):
a[:] = a[::-1] # reverse the array
a = np.array([0, 1, 2])
foo(a)
print(a)
使用 NumPy(没有 @jit
)打印 [2, 1, 0]
,而使用 Numba(使用 @jit
)打印 [2, 1, 2]
。看起来 Numba 就地修改了数组,这会导致数据损坏。通过复制数组很容易解决:
a[:] = a[::-1].copy()
但这是期望的行为吗? Numba 和 NumPy 不应该给出相同的结果吗?
我在 Python 3.5.2.
中使用 Numba v0.26.0
您的 jit
存在与此 Python 循环相同的就地问题。
In [718]: x=list(range(3))
In [719]: for i in range(3):
...: x[i] = x[2-i]
In [720]: x
Out[720]: [2, 1, 2]
x[:] = x[::-1]
是缓冲的,不是因为 numpy 认识到发生了一些特殊的事情,而是因为它总是在做赋值时使用某种缓冲。
Python 解释器将 []
符号转换为对 __setitem__
和 __getitem__
的调用。所以 681 和 682 做同样的事情:
In [680]: x=np.arange(3)
In [681]: x[:] = x[::-1]
In [682]: x.__setitem__(slice(None), x.__getitem__(slice(None,None,-1)))
In [683]: x
Out[683]: array([0, 1, 2])
这意味着 x[::-1]
在被复制到 x[:]
之前被完全评估为一个临时数组。现在 x[::-1]
是视图,而不是副本,因此 setitem
步骤必须执行某种缓冲复制。
另一种复制方式是
np.copyto(x, x[::-1])
检查 x.__array_interface__
我发现数据缓冲区地址保持不变。所以它是在做一个拷贝,而不只是改变数据缓冲区地址。但它是低级编译代码。
通常缓冲只是一个用户不需要担心的实现问题。 ufunc.at 旨在处理缓冲产生问题的情况。这个话题会定期出现;搜索 add.at
.
=============
请注意 Python 列表的行为方式相同。 'get/setitem' 的翻译是一样的。
In [699]: x=list(range(3))
In [700]: x[:] = x[::-1]
In [701]: x
Out[701]: [2, 1, 0]
======================
我不完全确定这是否相关,但由于我测试了这些想法,所以我会记录它们。 https://docs.scipy.org/doc/numpy/reference/arrays.nditer.html 建议使用 np.nditer
作为在 cython
中实现迭代任务的垫脚石。
使用 nditer
的第一次尝试是:
In [769]: x=np.arange(5)
In [770]: it = np.nditer((x,x[::-1]), op_flags=[['readwrite'], ['readonly']])
In [771]: for i,j in it:
...: print(i,j)
...: i[...] = j
...:
0 4
1 3
2 2
3 3
4 4
In [772]: x
Out[772]: array([4, 3, 2, 3, 4])
这会产生与 numba
相同类型的重叠结果。
添加一个副本可以完全反转。
it = np.nditer((x,x[::-1].copy()), op_flags=[['readwrite'], ['readonly']])
如果我添加 external_loop
标志,我也会得到一个干净的反转:
In [781]: x=np.arange(5)
In [782]: it = np.nditer((x,x[::-1]), op_flags=[['readwrite'], ['readonly']], fl
...: ags = ['external_loop'])
In [783]: for i,j in it:
...: print(i,j)
...: i[...] = j
...:
[0 1 2 3 4] [4 3 2 1 0]
In [784]: x
Out[784]: array([4, 3, 2, 1, 0])
这是一个已知问题 (https://github.com/numba/numba/issues/1960),已在 numba 0.27 中修复。按照 NumPy 行为,此修复程序会检测重叠并制作临时副本以避免损坏数据。
Numba 和 NumPy 不会以相同的方式执行以下 foo
函数:
from numba import jit
import numpy as np
@jit
def foo(a):
a[:] = a[::-1] # reverse the array
a = np.array([0, 1, 2])
foo(a)
print(a)
使用 NumPy(没有 @jit
)打印 [2, 1, 0]
,而使用 Numba(使用 @jit
)打印 [2, 1, 2]
。看起来 Numba 就地修改了数组,这会导致数据损坏。通过复制数组很容易解决:
a[:] = a[::-1].copy()
但这是期望的行为吗? Numba 和 NumPy 不应该给出相同的结果吗?
我在 Python 3.5.2.
中使用 Numba v0.26.0您的 jit
存在与此 Python 循环相同的就地问题。
In [718]: x=list(range(3))
In [719]: for i in range(3):
...: x[i] = x[2-i]
In [720]: x
Out[720]: [2, 1, 2]
x[:] = x[::-1]
是缓冲的,不是因为 numpy 认识到发生了一些特殊的事情,而是因为它总是在做赋值时使用某种缓冲。
Python 解释器将 []
符号转换为对 __setitem__
和 __getitem__
的调用。所以 681 和 682 做同样的事情:
In [680]: x=np.arange(3)
In [681]: x[:] = x[::-1]
In [682]: x.__setitem__(slice(None), x.__getitem__(slice(None,None,-1)))
In [683]: x
Out[683]: array([0, 1, 2])
这意味着 x[::-1]
在被复制到 x[:]
之前被完全评估为一个临时数组。现在 x[::-1]
是视图,而不是副本,因此 setitem
步骤必须执行某种缓冲复制。
另一种复制方式是
np.copyto(x, x[::-1])
检查 x.__array_interface__
我发现数据缓冲区地址保持不变。所以它是在做一个拷贝,而不只是改变数据缓冲区地址。但它是低级编译代码。
通常缓冲只是一个用户不需要担心的实现问题。 ufunc.at 旨在处理缓冲产生问题的情况。这个话题会定期出现;搜索 add.at
.
=============
请注意 Python 列表的行为方式相同。 'get/setitem' 的翻译是一样的。
In [699]: x=list(range(3))
In [700]: x[:] = x[::-1]
In [701]: x
Out[701]: [2, 1, 0]
======================
我不完全确定这是否相关,但由于我测试了这些想法,所以我会记录它们。 https://docs.scipy.org/doc/numpy/reference/arrays.nditer.html 建议使用 np.nditer
作为在 cython
中实现迭代任务的垫脚石。
使用 nditer
的第一次尝试是:
In [769]: x=np.arange(5)
In [770]: it = np.nditer((x,x[::-1]), op_flags=[['readwrite'], ['readonly']])
In [771]: for i,j in it:
...: print(i,j)
...: i[...] = j
...:
0 4
1 3
2 2
3 3
4 4
In [772]: x
Out[772]: array([4, 3, 2, 3, 4])
这会产生与 numba
相同类型的重叠结果。
添加一个副本可以完全反转。
it = np.nditer((x,x[::-1].copy()), op_flags=[['readwrite'], ['readonly']])
如果我添加 external_loop
标志,我也会得到一个干净的反转:
In [781]: x=np.arange(5)
In [782]: it = np.nditer((x,x[::-1]), op_flags=[['readwrite'], ['readonly']], fl
...: ags = ['external_loop'])
In [783]: for i,j in it:
...: print(i,j)
...: i[...] = j
...:
[0 1 2 3 4] [4 3 2 1 0]
In [784]: x
Out[784]: array([4, 3, 2, 1, 0])
这是一个已知问题 (https://github.com/numba/numba/issues/1960),已在 numba 0.27 中修复。按照 NumPy 行为,此修复程序会检测重叠并制作临时副本以避免损坏数据。