融合类型的 Cython 缓存

Cython caching with fused type

我正在尝试将融合类型添加到 scipy.stats.qmc._sobol.pyx

问题是我们正在缓存两个矩阵,以便下次运行函数时不需要加载矩阵。我们如下声明它们,并在函数中使用一些全局变量来填充它们。

cdef cnp.uint64_t poly[MAXDIM]
cdef cnp.uint64_t vinit[MAXDIM][MAXDEG]

如果我尝试使用融合类型而不是 cnp.uint64_t,我会遇到编译错误。事实上,Cython 不能在这里决定类型。

我想到的解决办法是声明两次矩阵。一个用于 cnp.uint32_t,一个用于 cnp.uint64_t,然后我可以检测每个函数中是否需要第一组或第二组。但是恐怕会增加内存占用。我考虑过释放数组之一,但是如果用户同时调用矩阵并要求 32 位和 64 位,那么它可能会中断。

cdef cnp.uint32_t poly_32[MAXDIM]
cdef cnp.uint32_t vinit_32[MAXDIM][MAXDEG]

cdef cnp.uint64_t poly_64[MAXDIM]
cdef cnp.uint64_t vinit_64[MAXDIM][MAXDEG]

是否有替代方法来缓存矩阵并使用融合类型?矩阵需要是 cnp.uint32_tcnp.uint64_t。并不是说 64 位架构的人可以要求使用 32 位的功能,所以我不能真正限制 64 位架构上的 64 位。


这里有一些或多或少完整的代码来解释 Cython 和 Python 中的整个逻辑:

cdef cnp.uint64_t poly[MAXDIM]
cdef cnp.uint64_t vinit[MAXDIM][MAXDEG]

cdef bint is_initialized = False

def _initialize_direction_numbers():
    global is_initialized
    if not is_initialized:
        for i in range(...):
            poly[i] = ...
            vinit[i] = ...
        is_initialized = True

def _initialize_v(...):
    # use the cached values
    for i in range(...):
        ... = poly[i]
        ... = vinit[i]

在Python我有

_initialize_direction_numbers()
_initialize_v(...)

再次调用这 2 个函数将不会再次加载矩阵,因为 is_initialized = True.

我不会对 C 数组执行此操作(即 cdef cnp.uint64_t poly[MAXDIM])。它们的缺点是:

  • 无论是否实际初始化,它们都会使用内存
  • 它们很可能会生成缺少 stack-allocated 临时对象,这可能会导致错误(尽管数组本身不会分配堆栈)。

相反,我可能会使用 dict 个 Numpy 数组。这实际上并不涉及使用融合类型。

_poly_dict = {}
_vinit_dict = {}

def get_poly(dtype):
    poly = _poly_dict.get(dtype)
    if not poly:
       _poly_dict[dtype] = np.empty(..., dtype=dtype)
       # ... initialize it
    return poly

# etc.

然后您可以做的是创建这些数组的内存视图(可能在融合函数中)。内存视图的创建速度非常快,因为它们只访问现有内存。像

cdef fused int32or64:
   cnp.uint32_t
   cnp.uint64_t

def do_calculation(int32or64 user_value):
   # slightly awkward conversion from ctype to Numpy dtype 
   #  - if you have to do this often the use a helper function
   cdef int32or64[:] poly = get_poly(np.int32 if int32or64 is cnp.uint32_t else np.int64)
   # your calculation goes here...

顺便说一句,如果您想在 get_poly 中使用融合类型的内存视图(例如,初始化数组),添加一个伪参数通常很有用:

def get_poly(dtype, int32or64 dummy):
   ...

即使它没有自然的“输入”,您也可以将其生成为融合函数(从而避免重复代码)。