在 numba jitclass 中有 float64 或 float32 属性

Have float64 or float32 attribute in numba jitclass

如何使用参数可以是 float64 或 float32 的 numba jitclass?使用函数,以下代码有效:

import numba
import numpy as np
from numba import njit
from numba.experimental import jitclass


@njit()
def f(a):
    print(a.dtype)
    return a[0]


a = np.zeros(3)
f(a)
f(a.astype(np.float32))

尝试将 float32 和 float64 与 class 属性一起使用时失败:

@jitclass([('arr', numba.types.float64[:])])
class MyClass():
    def __init__(self):
        pass

    def f(self, a):
        self.arr = a


myclass = MyClass()
myclass.f(np.zeros(3))
# following line fails:
myclass.f(np.zeros(3, dtype=np.float32))

有解决办法吗?

当您调用 MyClass() 时,Numba 需要实例化一个 class 并且因为 Numba 仅适用于 well-defined 强类型(这就是它快速且有用的原因), class 的字段需要在对象实例化之前输入。因此,当调用方法 f 时,您不能定义 MyClass 字段的类型,因为此调用是由动态的 CPython 解释器进行的。请注意,class 通常有不止一种方法(否则这样的 class 不会很有用),这就是部分编译也不可能的原因。

解决此问题的一个简单方法是使用两种类型:

class MyClass():
    def __init__(self):
        pass

    def f(self, a):
        self.arr = a

MyClass_float32 = jitclass([('arr', numba.types.float32[:])])(MyClass)
MyClass_float64 = jitclass([('arr', numba.types.float64[:])])(MyClass)

myclass = MyClass_float32() # Instantiate the class lazily and an object
# `self.arr` is already instantiated here and it has `float32[:]` type.
myclass.f(np.zeros(3, dtype=np.float32))

myclass = MyClass_float64()
myclass.f(np.zeros(3, dtype=np.float64))

Numba 支持模板化内核,但通过自定义类型选择,如下所示:

import numpy as np

from numba import generated_jit, types

@generated_jit(nopython=True)
def is_missing(x):
    """
    Return True if the value is missing, False otherwise.
    """
    if isinstance(x, types.Float):
        return lambda x: np.isnan(x)
    elif isinstance(x, (types.NPDatetime, types.NPTimedelta)):
        # The corresponding Not-a-Time value
        missing = x('NaT')
        return lambda x: x == missing
    else:
        return lambda x: Fals

诀窍是使用 'generated_jit' 注释。