在 numba jitclass 中有 float64 或 float32 属性
Have float64 or float32 attribute in numba jitclass
如何使用参数可以是 float64 或 float32 的 numba jitclass?使用函数,以下代码有效:
import numba
import numpy as np
from numba import njit
from numba.experimental import jitclass
@njit()
def f(a):
print(a.dtype)
return a[0]
a = np.zeros(3)
f(a)
f(a.astype(np.float32))
尝试将 float32 和 float64 与 class 属性一起使用时失败:
@jitclass([('arr', numba.types.float64[:])])
class MyClass():
def __init__(self):
pass
def f(self, a):
self.arr = a
myclass = MyClass()
myclass.f(np.zeros(3))
# following line fails:
myclass.f(np.zeros(3, dtype=np.float32))
有解决办法吗?
当您调用 MyClass()
时,Numba 需要实例化一个 class 并且因为 Numba 仅适用于 well-defined 强类型(这就是它快速且有用的原因), class 的字段需要在对象实例化之前输入。因此,当调用方法 f
时,您不能定义 MyClass
字段的类型,因为此调用是由动态的 CPython 解释器进行的。请注意,class 通常有不止一种方法(否则这样的 class 不会很有用),这就是部分编译也不可能的原因。
解决此问题的一个简单方法是使用两种类型:
class MyClass():
def __init__(self):
pass
def f(self, a):
self.arr = a
MyClass_float32 = jitclass([('arr', numba.types.float32[:])])(MyClass)
MyClass_float64 = jitclass([('arr', numba.types.float64[:])])(MyClass)
myclass = MyClass_float32() # Instantiate the class lazily and an object
# `self.arr` is already instantiated here and it has `float32[:]` type.
myclass.f(np.zeros(3, dtype=np.float32))
myclass = MyClass_float64()
myclass.f(np.zeros(3, dtype=np.float64))
Numba 支持模板化内核,但通过自定义类型选择,如下所示:
import numpy as np
from numba import generated_jit, types
@generated_jit(nopython=True)
def is_missing(x):
"""
Return True if the value is missing, False otherwise.
"""
if isinstance(x, types.Float):
return lambda x: np.isnan(x)
elif isinstance(x, (types.NPDatetime, types.NPTimedelta)):
# The corresponding Not-a-Time value
missing = x('NaT')
return lambda x: x == missing
else:
return lambda x: Fals
诀窍是使用 'generated_jit' 注释。
如何使用参数可以是 float64 或 float32 的 numba jitclass?使用函数,以下代码有效:
import numba
import numpy as np
from numba import njit
from numba.experimental import jitclass
@njit()
def f(a):
print(a.dtype)
return a[0]
a = np.zeros(3)
f(a)
f(a.astype(np.float32))
尝试将 float32 和 float64 与 class 属性一起使用时失败:
@jitclass([('arr', numba.types.float64[:])])
class MyClass():
def __init__(self):
pass
def f(self, a):
self.arr = a
myclass = MyClass()
myclass.f(np.zeros(3))
# following line fails:
myclass.f(np.zeros(3, dtype=np.float32))
有解决办法吗?
当您调用 MyClass()
时,Numba 需要实例化一个 class 并且因为 Numba 仅适用于 well-defined 强类型(这就是它快速且有用的原因), class 的字段需要在对象实例化之前输入。因此,当调用方法 f
时,您不能定义 MyClass
字段的类型,因为此调用是由动态的 CPython 解释器进行的。请注意,class 通常有不止一种方法(否则这样的 class 不会很有用),这就是部分编译也不可能的原因。
解决此问题的一个简单方法是使用两种类型:
class MyClass():
def __init__(self):
pass
def f(self, a):
self.arr = a
MyClass_float32 = jitclass([('arr', numba.types.float32[:])])(MyClass)
MyClass_float64 = jitclass([('arr', numba.types.float64[:])])(MyClass)
myclass = MyClass_float32() # Instantiate the class lazily and an object
# `self.arr` is already instantiated here and it has `float32[:]` type.
myclass.f(np.zeros(3, dtype=np.float32))
myclass = MyClass_float64()
myclass.f(np.zeros(3, dtype=np.float64))
Numba 支持模板化内核,但通过自定义类型选择,如下所示:
import numpy as np
from numba import generated_jit, types
@generated_jit(nopython=True)
def is_missing(x):
"""
Return True if the value is missing, False otherwise.
"""
if isinstance(x, types.Float):
return lambda x: np.isnan(x)
elif isinstance(x, (types.NPDatetime, types.NPTimedelta)):
# The corresponding Not-a-Time value
missing = x('NaT')
return lambda x: x == missing
else:
return lambda x: Fals
诀窍是使用 'generated_jit' 注释。