Numba 通过就地影响破坏数据

Numba corrupts data by affecting in-place

Numba 和 NumPy 不会以相同的方式执行以下 foo 函数:

from numba import jit
import numpy as np

@jit
def foo(a):
    a[:] = a[::-1] # reverse the array

a = np.array([0, 1, 2])
foo(a)
print(a)

使用 NumPy(没有 @jit)打印 [2, 1, 0],而使用 Numba(使用 @jit)打印 [2, 1, 2]。看起来 Numba 就地修改了数组,这会导致数据损坏。通过复制数组很容易解决:

a[:] = a[::-1].copy()

但这是期望的行为吗? Numba 和 NumPy 不应该给出相同的结果吗?

我在 Python 3.5.2.

中使用 Numba v0.26.0

您的 jit 存在与此 Python 循环相同的就地问题。

In [718]: x=list(range(3))
In [719]: for i in range(3):
     ...:     x[i] = x[2-i]
In [720]: x
Out[720]: [2, 1, 2]

x[:] = x[::-1] 是缓冲的,不是因为 numpy 认识到发生了一些特殊的事情,而是因为它总是在做赋值时使用某种缓冲。

Python 解释器将 [] 符号转换为对 __setitem____getitem__ 的调用。所以 681 和 682 做同样的事情:

In [680]: x=np.arange(3)
In [681]: x[:] = x[::-1]
In [682]: x.__setitem__(slice(None), x.__getitem__(slice(None,None,-1)))
In [683]: x
Out[683]: array([0, 1, 2])

这意味着 x[::-1] 在被复制到 x[:] 之前被完全评估为一个临时数组。现在 x[::-1] 是视图,而不是副本,因此 setitem 步骤必须执行某种缓冲复制。

另一种复制方式是

np.copyto(x, x[::-1])

检查 x.__array_interface__ 我发现数据缓冲区地址保持不变。所以它是在做一个拷贝,而不只是改变数据缓冲区地址。但它是低级编译代码。

通常缓冲只是一个用户不需要担心的实现问题。 ufunc.at 旨在处理缓冲产生问题的情况。这个话题会定期出现;搜索 add.at.

=============

请注意 Python 列表的行为方式相同。 'get/setitem' 的翻译是一样的。

In [699]: x=list(range(3))
In [700]: x[:] = x[::-1]
In [701]: x
Out[701]: [2, 1, 0]

======================

我不完全确定这是否相关,但由于我测试了这些想法,所以我会记录它们。 https://docs.scipy.org/doc/numpy/reference/arrays.nditer.html 建议使用 np.nditer 作为在 cython 中实现迭代任务的垫脚石。

使用 nditer 的第一次尝试是:

In [769]: x=np.arange(5)
In [770]: it = np.nditer((x,x[::-1]), op_flags=[['readwrite'], ['readonly']])
In [771]: for i,j in it:
     ...:     print(i,j)
     ...:     i[...] = j
     ...:     
0 4
1 3
2 2
3 3
4 4
In [772]: x
Out[772]: array([4, 3, 2, 3, 4])

这会产生与 numba 相同类型的重叠结果。

添加一个副本可以完全反转。

it = np.nditer((x,x[::-1].copy()), op_flags=[['readwrite'], ['readonly']])

如果我添加 external_loop 标志,我也会得到一个干净的反转:

In [781]: x=np.arange(5)
In [782]: it = np.nditer((x,x[::-1]), op_flags=[['readwrite'], ['readonly']], fl
     ...: ags = ['external_loop'])
In [783]: for i,j in it:
     ...:     print(i,j)
     ...:     i[...] = j
     ...:     
[0 1 2 3 4] [4 3 2 1 0]
In [784]: x
Out[784]: array([4, 3, 2, 1, 0])

这是一个已知问题 (https://github.com/numba/numba/issues/1960),已在 numba 0.27 中修复。按照 NumPy 行为,此修复程序会检测重叠并制作临时副本以避免损坏数据。