在 numba 中用数组索引数组
Index an array with an array in numba
作为我想编译的更复杂函数的一部分
numba,我必须用另一个数组 idx
.
索引一个数组 A
重要的是,数组的维数A
是可变的。
形状可以是(N)
、(N,N)
、(N,N,N)
等
在 python 中,我可以使用元组来做到这一点:
def test():
A = np.arange(5*5*5).reshape(5,5,5)
idx = np.array([0,2,4])
return A[tuple(idx)]
但是,显然不支持使用元组索引数组
numba 因为我收到以下错误:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<class 'tuple'>) found for signature:
>>> tuple(array(int64, 1d, C))
请注意,这只是一个最小的工作示例。一般来说,我不
知道 idx
.
的长度
我想将 A
重塑为向量并转换
idx
到相应的标量索引。
这是最好的解决方案还是有一个简单的“numba 替代方案”
用元组索引?
Importantly, the dimension of the array A is variable. The shape can be (N), (N,N), or (N,N,N) etc.
数组的维数是其在 Numba 中的类型的一部分。这意味着具有不同维度的 A
数组具有不同的类型,因此处理它的函数需要编译为多种类型。实际上,Numba 需要为 编译函数 的所有 input/output 参数(和内部变量)设置 well-defined 类型 .由于在运行时定义了可变数量的维度,您需要可变数量的编译函数和一个 Python 包装器来选择正确的函数实现。问题是编译一个函数非常昂贵,最大的问题是你指出的数组索引。
In python, I can do so using tuples
这是因为 Python 对 dynamically-typed 个对象进行操作。但这也是 Python 与 Numba 相比如此慢的主要原因之一。
indexing an array with a tuple is apparently not supported by numba
应该支持使用元组索引数组,但在 Numba 中不可能从动态 variable-sized 数组创建元组(而且肯定永远不会)。实际上,元组的类型(在编译时需要)基本上由其项目列表组成,这些项目依赖于动态 属性(仅在运行时才知道)。有关这方面的更多信息,请阅读:Numba Creating a tuple from a list .
Is that the best solution or is there a simple "numba alternative" to indexing with a tuple?
总的来说,您当然应该放弃使用 Numba 操作 variable-dimension 数组的想法,因为它不是为此而设计的。话虽这么说,但这并不意味着您无法使用 Numba 解决此类问题。
一种解决方案是对展平数组和展平索引进行操作。这个想法是用 A.ravel()
压平 A
,然后将其交给 Numba。形状也可以这样做:A.shape
可以使用 shape = np.array(A.shape, dtype=np.int)
作为动态一维数组传递给 Numba(在 Numba 函数之外执行)。为了清楚(和理智)起见,我假设数组是连续的。然后,Numba 函数可以使用如下表达式访问 A
的给定项目:
A[idx[0]*strides[0] + idx[1]*strides[1] + ... + idx[n-1]*strides[n-1]]
其中 n = len(shape)
。步长可以预先计算一次使用:
strides = np.array(n, dtype=np.int64)
for i in range(n):
strides[i] = 1
for j in range(i+1, n):
strides[i] *= shape[j]
您可以使用基本循环在运行时计算展平索引:
pos = np.int64(0)
for i in range(n):
pos += np.int64(idx[i]) * strides[i]
# A[pos]
请注意,执行此类操作效率低下,尤其是对于连续访问。也就是说,动态 variable-sized 数组从根本上说是低效的。您可以为某些特定的 n
值编译 Numba 代码,以加快这些特定情况(例如,具有少量维度的数组,如 <=8)。对于具有非常量维度的 high-order 数组,AFAIK 没有快速通用的方法来做到这一点。您需要对展平数组的 (relatively-large) 连续切片 进行操作,以帮助 Numba 生成更快的代码。
作为我想编译的更复杂函数的一部分
numba,我必须用另一个数组 idx
.
A
重要的是,数组的维数A
是可变的。
形状可以是(N)
、(N,N)
、(N,N,N)
等
在 python 中,我可以使用元组来做到这一点:
def test():
A = np.arange(5*5*5).reshape(5,5,5)
idx = np.array([0,2,4])
return A[tuple(idx)]
但是,显然不支持使用元组索引数组 numba 因为我收到以下错误:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<class 'tuple'>) found for signature:
>>> tuple(array(int64, 1d, C))
请注意,这只是一个最小的工作示例。一般来说,我不
知道 idx
.
我想将 A
重塑为向量并转换
idx
到相应的标量索引。
这是最好的解决方案还是有一个简单的“numba 替代方案” 用元组索引?
Importantly, the dimension of the array A is variable. The shape can be (N), (N,N), or (N,N,N) etc.
数组的维数是其在 Numba 中的类型的一部分。这意味着具有不同维度的 A
数组具有不同的类型,因此处理它的函数需要编译为多种类型。实际上,Numba 需要为 编译函数 的所有 input/output 参数(和内部变量)设置 well-defined 类型 .由于在运行时定义了可变数量的维度,您需要可变数量的编译函数和一个 Python 包装器来选择正确的函数实现。问题是编译一个函数非常昂贵,最大的问题是你指出的数组索引。
In python, I can do so using tuples
这是因为 Python 对 dynamically-typed 个对象进行操作。但这也是 Python 与 Numba 相比如此慢的主要原因之一。
indexing an array with a tuple is apparently not supported by numba
应该支持使用元组索引数组,但在 Numba 中不可能从动态 variable-sized 数组创建元组(而且肯定永远不会)。实际上,元组的类型(在编译时需要)基本上由其项目列表组成,这些项目依赖于动态 属性(仅在运行时才知道)。有关这方面的更多信息,请阅读:Numba Creating a tuple from a list .
Is that the best solution or is there a simple "numba alternative" to indexing with a tuple?
总的来说,您当然应该放弃使用 Numba 操作 variable-dimension 数组的想法,因为它不是为此而设计的。话虽这么说,但这并不意味着您无法使用 Numba 解决此类问题。
一种解决方案是对展平数组和展平索引进行操作。这个想法是用 A.ravel()
压平 A
,然后将其交给 Numba。形状也可以这样做:A.shape
可以使用 shape = np.array(A.shape, dtype=np.int)
作为动态一维数组传递给 Numba(在 Numba 函数之外执行)。为了清楚(和理智)起见,我假设数组是连续的。然后,Numba 函数可以使用如下表达式访问 A
的给定项目:
A[idx[0]*strides[0] + idx[1]*strides[1] + ... + idx[n-1]*strides[n-1]]
其中 n = len(shape)
。步长可以预先计算一次使用:
strides = np.array(n, dtype=np.int64)
for i in range(n):
strides[i] = 1
for j in range(i+1, n):
strides[i] *= shape[j]
您可以使用基本循环在运行时计算展平索引:
pos = np.int64(0)
for i in range(n):
pos += np.int64(idx[i]) * strides[i]
# A[pos]
请注意,执行此类操作效率低下,尤其是对于连续访问。也就是说,动态 variable-sized 数组从根本上说是低效的。您可以为某些特定的 n
值编译 Numba 代码,以加快这些特定情况(例如,具有少量维度的数组,如 <=8)。对于具有非常量维度的 high-order 数组,AFAIK 没有快速通用的方法来做到这一点。您需要对展平数组的 (relatively-large) 连续切片 进行操作,以帮助 Numba 生成更快的代码。