`numba` 和 `numpy.concatenate`

`numba` and `numpy.concatenate`

我正在尝试使用 numba 来加速一些代码,但这很难。例如下面的函数不会numba-fy,

@jit(nopython=True)
def returns(Ft, x, delta):
    T = len(x)
    rets = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
    return np.concatenate([[0], rets])

因为numba找不到np.concatenate的签名。对此有规范的修复吗?

有点晚了,但我希望仍然有用。既然你要求“规范修复”,我想解释为什么 concatenate 在使用数组时是个坏主意,特别是如果你表示你想消除瓶颈并因此使用 numba jit。数组是内存中连续的字节序列(numpy 知道一些技巧可以在不通过创建视图进行复制的情况下更改顺序,但这是另一个话题,请参阅 https://towardsdatascience.com/advanced-numpy-master-stride-tricks-with-25-illustrated-exercises-923a9393ab20)。如果要将值 x 添加到包含 N 个元素的数组中,则需要创建一个包含 N+1 个元素的新数组,将第一个值设置为 x 并复制剩余部分。作为旁注,类似的论点适用于将项目添加到 python 列表,这就是 collections.deque 存在的原因。

现在,在您的 jit 装饰函数中,您可能希望编译器理解您想要做什么,但是编写始终理解您尝试做的事情的编译器几乎是不可能的。因此,最好善待编译器,并在您知道正确的选择时帮助进行内存布局。因此,恕我直言,您的示例代码的“规范修复”将类似于以下内容:

@jit(nopython=True)
def returns(Ft, x, delta):
    T = len(x)
    rets = np.empty_like(x)
    rets[0] = 0
    rets[1:T] = Ft[0:T - 1] * x[1:T] - delta * np.abs(Ft[1:T] - Ft[0:T - 1])
    return rets

总的来说,我同意@Aaron 的评论,这意味着对于在 jit 装饰函数中调用的任何函数,您应该始终尽可能明确地指定输入类型。在您的情况下,作为编译器问自己“什么是 [[0], rets]?”。考虑严格类型,您会看到一个列表,其中包含一个整数列表和一个浮点(或复数)数组。对于编译器来说,这是一种具有挑战性的类型混合。输出应该变成整数数组还是浮点数数组?