numba np.diff 有错误吗?

Does numba np.diff have a bug?

我遇到了这个问题,np.diff 的 numba 实现在矩阵的一部分上不起作用。这是一个错误还是我做错了什么?

import numpy as np
from numba import njit
v = np.ones((2,2))
np.diff(v[:,0])
array([0.])
@njit
def numbadiff(x):
    return np.diff(x)

numbadiff(v[:,0])

最后一次调用 returns 出错了,但我不确定为什么。

问题是 Numba 中的 np.diff 进行内部重塑,仅 contiguous arrays. The slice that you are making, v[:, 0], is not contiguous, hence the error. You can get it to work using np.ascontiguousarray 支持,returns 给定数组的连续副本(如果尚未存在)连续:

numbadiff(np.ascontiguousarray(v[:, 0]))

请注意,您也可以避免 np.diff 并将 numbadiff 重新定义为:

@njit
def numbadiff(x):
    return x[1:] - x[:-1]

当你遇到错误时,礼貌的做法是显示错误。有时带有回溯的完整错误是合适的。对于numba 说的可能太多了,但是你应该尽量post 一个总结。它使我们更容易,特别是如果我们无法 运行 您的代码并自己查看错误。你甚至可能学到一些东西。

我运行你的例子得到了(部分):

In [428]: numbadiff(np.ones((2,2))[:,0])                                        
---------------------------------------------------------------------------
TypingError    
...
TypeError: reshape() supports contiguous array only
...
    def diff_impl(a, n=1):
        <source elided>
        # To make things easier, normalize input and output into 2d arrays
        a2 = a.reshape((-1, size))
...
TypeError: reshape() supports contiguous array only
....
This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

这支持@jdehesa 提供的诊断和修复。这不是 numba 中的错误;是你输入的问题。

使用 numba 的一个缺点是错误更难理解。另一个显然是它对输入(例如此数组视图)不太灵活。如果你真的想要速度优势,你需要愿意自己深入研究错误信息。