如何在 numba 中创建一个 np.array 与输入相关的等级?
How to make an np.array in numba with input-dependent rank?
我想 @numba.njit
这个简单的函数 returns 一个数组,其形状,特别是等级,取决于输入 i
:
例如。对于 i = 4
,形状应为 shape=(2, 2, 2, 2, 4)
import numpy as np
from numba import njit
@njit
def make_array_numba(i):
shape = np.array([2] * i + [i], dtype=np.int64)
return np.empty(shape, dtype=np.int64)
make_array_numba(4).shape
我尝试了很多不同的方法,但总是失败,因为我无法生成 numba
想要在 np.empty
/ [=20= 中看到的 shape
元组] / np.zeros
/...
在正常的 numpy
中,可以将列表 / np.arrays 作为 shape
传递,或者我可以动态生成一个元组,例如 (2,) * i + (i,)
.
输出:
>>> empty(array(int64, 1d, C), dtype=class(int64))
There are 4 candidate implementations:
- Of which 4 did not match due to:
Overload in function '_OverloadWrapper._build.<locals>.ol_generated': File: numba/core/overload_glue.py: Line 131.
With argument(s): '(array(int64, 1d, C), dtype=class(int64))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<intrinsic stub>) found for signature:
>>> stub(array(int64, 1d, C), class(int64))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Intrinsic of function 'stub': File: numba/core/overload_glue.py: Line 35.
With argument(s): '(array(int64, 1d, C), class(int64))':
No match.
只有 @njit
才能做到这一点。原因是 Numba 需要为数组设置一个独立于变量值的类型,以便编译函数然后才执行它。问题是数组的 维度是其类型 的一部分。因此,在这里,Numba 无法找到数组的类型,因为它依赖于一个不是 compile-time 常量.
的值
解决这个问题的唯一方法(假设你不想线性化你的数组)是为每个可能的 i
重新编译函数,这肯定是矫枉过正并且完全破坏了使用 Numba 的好处(在至少在你的例子中)。请注意,当您真的想为不同的值或输入类型重新编译函数时,可以在这种情况下使用 @generated_jit
。我强烈建议您不要将它用于当前 use-case。如果您尝试,那么由于无法使用 runtime-defined 变量对数组进行索引,您还会遇到其他类似的问题,并且生成的代码很快就会变得疯狂。
一个更通用和更简洁的解决方案是简单地线性化数组。这意味着将其展平并执行一些奇特的索引计算,如 (((... + z) * stride_z) + y) * stride_y + x
。大小和索引可以在运行时独立于类型系统计算。请注意,索引可能会很慢,但在这种情况下 Numpy 不会使用更快的代码。
我想 @numba.njit
这个简单的函数 returns 一个数组,其形状,特别是等级,取决于输入 i
:
例如。对于 i = 4
,形状应为 shape=(2, 2, 2, 2, 4)
import numpy as np
from numba import njit
@njit
def make_array_numba(i):
shape = np.array([2] * i + [i], dtype=np.int64)
return np.empty(shape, dtype=np.int64)
make_array_numba(4).shape
我尝试了很多不同的方法,但总是失败,因为我无法生成 numba
想要在 np.empty
/ [=20= 中看到的 shape
元组] / np.zeros
/...
在正常的 numpy
中,可以将列表 / np.arrays 作为 shape
传递,或者我可以动态生成一个元组,例如 (2,) * i + (i,)
.
输出:
>>> empty(array(int64, 1d, C), dtype=class(int64))
There are 4 candidate implementations:
- Of which 4 did not match due to:
Overload in function '_OverloadWrapper._build.<locals>.ol_generated': File: numba/core/overload_glue.py: Line 131.
With argument(s): '(array(int64, 1d, C), dtype=class(int64))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<intrinsic stub>) found for signature:
>>> stub(array(int64, 1d, C), class(int64))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Intrinsic of function 'stub': File: numba/core/overload_glue.py: Line 35.
With argument(s): '(array(int64, 1d, C), class(int64))':
No match.
只有 @njit
才能做到这一点。原因是 Numba 需要为数组设置一个独立于变量值的类型,以便编译函数然后才执行它。问题是数组的 维度是其类型 的一部分。因此,在这里,Numba 无法找到数组的类型,因为它依赖于一个不是 compile-time 常量.
解决这个问题的唯一方法(假设你不想线性化你的数组)是为每个可能的 i
重新编译函数,这肯定是矫枉过正并且完全破坏了使用 Numba 的好处(在至少在你的例子中)。请注意,当您真的想为不同的值或输入类型重新编译函数时,可以在这种情况下使用 @generated_jit
。我强烈建议您不要将它用于当前 use-case。如果您尝试,那么由于无法使用 runtime-defined 变量对数组进行索引,您还会遇到其他类似的问题,并且生成的代码很快就会变得疯狂。
一个更通用和更简洁的解决方案是简单地线性化数组。这意味着将其展平并执行一些奇特的索引计算,如 (((... + z) * stride_z) + y) * stride_y + x
。大小和索引可以在运行时独立于类型系统计算。请注意,索引可能会很慢,但在这种情况下 Numpy 不会使用更快的代码。