如何在 numba 中创建一个 np.array 与输入相关的等级?

How to make an np.array in numba with input-dependent rank?

我想 @numba.njit 这个简单的函数 returns 一个数组,其形状,特别是等级,取决于输入 i: 例如。对于 i = 4,形状应为 shape=(2, 2, 2, 2, 4)

import numpy as np
from numba import njit

@njit
def make_array_numba(i):
    shape = np.array([2] * i + [i], dtype=np.int64)
    return np.empty(shape, dtype=np.int64)

make_array_numba(4).shape

我尝试了很多不同的方法,但总是失败,因为我无法生成 numba 想要在 np.empty / [=20= 中看到的 shape 元组] / np.zeros /... 在正常的 numpy 中,可以将列表 / np.arrays 作为 shape 传递,或者我可以动态生成一个元组,例如 (2,) * i + (i,).

输出:

>>> empty(array(int64, 1d, C), dtype=class(int64))
 
There are 4 candidate implementations:
      - Of which 4 did not match due to:
      Overload in function '_OverloadWrapper._build.<locals>.ol_generated': File: numba/core/overload_glue.py: Line 131.
        With argument(s): '(array(int64, 1d, C), dtype=class(int64))':
       Rejected as the implementation raised a specific error:
         TypingError: Failed in nopython mode pipeline (step: nopython frontend)
       No implementation of function Function(<intrinsic stub>) found for signature:
        
        >>> stub(array(int64, 1d, C), class(int64))
        
       There are 2 candidate implementations:
         - Of which 2 did not match due to:
         Intrinsic of function 'stub': File: numba/core/overload_glue.py: Line 35.
           With argument(s): '(array(int64, 1d, C), class(int64))':
          No match.

只有 @njit 才能做到这一点。原因是 Numba 需要为数组设置一个独立于变量值的类型,以便编译函数然后才执行它。问题是数组的 维度是其类型 的一部分。因此,在这里,Numba 无法找到数组的类型,因为它依赖于一个不是 compile-time 常量.

的值

解决这个问题的唯一方法(假设你不想线性化你的数组)是为每个可能的 i 重新编译函数,这肯定是矫枉过正并且完全破坏了使用 Numba 的好处(在至少在你的例子中)。请注意,当您真的想为不同的值或输入类型重新编译函数时,可以在这种情况下使用 @generated_jit 。我强烈建议您不要将它用于当前 use-case。如果您尝试,那么由于无法使用 runtime-defined 变量对数组进行索引,您还会遇到其他类似的问题,并且生成的代码很快就会变得疯狂。

一个更通用和更简洁的解决方案是简单地线性化数组。这意味着将其展平并执行一些奇特的索引计算,如 (((... + z) * stride_z) + y) * stride_y + x。大小和索引可以在运行时独立于类型系统计算。请注意,索引可能会很慢,但在这种情况下 Numpy 不会使用更快的代码。