在 numba 中用数组索引数组

Index an array with an array in numba

作为我想编译的更复杂函数的一部分 numba,我必须用另一个数组 idx.

索引一个数组 A

重要的是,数组的维数A是可变的。 形状可以是(N)(N,N)(N,N,N)

在 python 中,我可以使用元组来做到这一点:

def test():
    A = np.arange(5*5*5).reshape(5,5,5)
    idx = np.array([0,2,4])
    return A[tuple(idx)]

但是,显然不支持使用元组索引数组 numba 因为我收到以下错误:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<class 'tuple'>) found for signature:
 
 >>> tuple(array(int64, 1d, C))

请注意,这只是一个最小的工作示例。一般来说,我不 知道 idx.

的长度

我想将 A 重塑为向量并转换 idx 到相应的标量索引。

这是最好的解决方案还是有一个简单的“numba 替代方案” 用元组索引?

Importantly, the dimension of the array A is variable. The shape can be (N), (N,N), or (N,N,N) etc.

数组的维数是其在 Numba 中的类型的一部分。这意味着具有不同维度的 A 数组具有不同的类型,因此处理它的函数需要编译为多种类型。实际上,Numba 需要为 编译函数 的所有 input/output 参数(和内部变量)设置 well-defined 类型 .由于在运行时定义了可变数量的维度,您需要可变数量的编译函数和一个 Python 包装器来选择正确的函数实现。问题是编译一个函数非常昂贵,最大的问题是你指出的数组索引。

In python, I can do so using tuples

这是因为 Python 对 dynamically-typed 个对象进行操作。但这也是 Python 与 Numba 相比如此慢的主要原因之一。

indexing an array with a tuple is apparently not supported by numba

应该支持使用元组索引数组,但在 Numba 中不可能从动态 variable-sized 数组创建元组(而且肯定永远不会)。实际上,元组的类型(在编译时需要)基本上由其项目列表组成,这些项目依赖于动态 属性(仅在运行时才知道)。有关这方面的更多信息,请阅读:Numba Creating a tuple from a list .

Is that the best solution or is there a simple "numba alternative" to indexing with a tuple?

总的来说,您当然应该放弃使用 Numba 操作 variable-dimension 数组的想法,因为它不是为此而设计的。话虽这么说,但这并不意味着您无法使用 Numba 解决此类问题。

一种解决方案是对展平数组和展平索引进行操作。这个想法是用 A.ravel() 压平 A,然后将其交给 Numba。形状也可以这样做:A.shape 可以使用 shape = np.array(A.shape, dtype=np.int) 作为动态一维数组传递给 Numba(在 Numba 函数之外执行)。为了清楚(和理智)起见,我假设数组是连续的。然后,Numba 函数可以使用如下表达式访问 A 的给定项目:

A[idx[0]*strides[0] + idx[1]*strides[1] + ... + idx[n-1]*strides[n-1]]

其中 n = len(shape)。步长可以预先计算一次使用:

strides = np.array(n, dtype=np.int64)
for i in range(n):
    strides[i] = 1
    for j in range(i+1, n):
        strides[i] *= shape[j]

您可以使用基本循环在运行时计算展平索引:

pos = np.int64(0)
for i in range(n):
    pos += np.int64(idx[i]) * strides[i]
# A[pos]

请注意,执行此类操作效率低下,尤其是对于连续访问。也就是说,动态 variable-sized 数组从根本上说是低效的。您可以为某些特定的 n 值编译 Numba 代码,以加快这些特定情况(例如,具有少量维度的数组,如 <=8)。对于具有非常量维度的 high-order 数组,AFAIK 没有快速通用的方法来做到这一点。您需要对展平数组的 (relatively-large) 连续切片 进行操作,以帮助 Numba 生成更快的代码。