在 numba jitted 函数中动态增长数组

dynamically growing array in numba jitted functions

numba 似乎不支持 numpy.resize

在 nopython 模式下 numba.jit 使用动态增长数组的最佳方法是什么?

到目前为止,我能做的最好的事情就是在 jitted 函数之外定义数组并调整其大小,是否有更好(更整洁)的选择?

我通常采用的策略是分配足够多的数组存储空间来容纳计算,然后跟踪最终使用的 index/indices,然后在返回之前将数组切分到实际大小。这假设您事先知道您可以将数组增长到的最大大小。我的想法是,在我自己的大多数应用程序中,内存很便宜,但在 python 和 jitted 函数之间调整大小和切换很多很昂贵。

numpy.resize 是一个 pure python function:

import numpy as np

def resize(a, new_shape):
    """I did some minor changes so it all works with just `import numpy as np`."""
    if isinstance(new_shape, (int, np.core.numerictypes.integer)):
        new_shape = (new_shape,)
    a = np.ravel(a)
    Na = len(a)
    if not Na:
        return np.zeros(new_shape, a.dtype)
    total_size = np.multiply.reduce(new_shape)
    n_copies = int(total_size / Na)
    extra = total_size % Na

    if total_size == 0:
        return a[:0]

    if extra != 0:
        n_copies = n_copies+1
        extra = Na-extra

    a = np.concatenate((a,)*n_copies)
    if extra > 0:
        a = a[:-extra]

    return np.reshape(a, new_shape)

对于一维数组,您可以直接实现。不幸的是,ND 数组要复杂得多,因为 nopython numba 函数不支持某些操作:isinstancereshape 和元组乘法。这是一维等效项:

import numpy as np
import numba as nb

@nb.njit
def resize(a, new_size):
    new = np.zeros(new_size, a.dtype)
    idx = 0
    while True:
        newidx = idx + a.size
        if newidx > new_size:
            new[idx:] = a[:new_size-newidx]
            break
        new[idx:newidx] = a
        idx = newidx
    return new

如果您不想要 "repeat input" 行为并且仅将其用于 增加大小 则更容易:

@nb.njit
def resize(a, new_size):
    new = np.zeros(new_size, a.dtype)
    new[:a.size] = a
    return new

这些函数用 numba.njit 装饰,因此可以在 nopython 模式下的任何 numba 函数中调用。


但请注意:通常您不想调整大小 - 或者如果您这样做,请确保您选择的方法具有 amoritzed O(1) cost (Wikipedia link)。如果您可以估计最大长度,那么最好立即预分配一个大小正确(或稍微过度分配)的数组。

要动态增加现有数组的大小(因此就地进行),必须使用 numpy.ndarray.resize 而不是 numpy.resize。此方法未在 Python 中实现,并且在 Numba 中不可用,因此无法完成。