Numba jit 和延迟类型
Numba jit and deferred types
我将 numba 作为函数的签名传递
@numba.jit(numba.types.UniTuple(numba.float64[:, :], 2)(
numba.float64[:, :], numba.float64[:, :], numba.float64[:, :],
earth_model_type))
其中 earth_model_type
定义为
earth_model_type = numba.deferred_type()
earth_model_type.define(em.EarthModel.class_type.instance_type)
它编译得很好,但是当我尝试调用函数时我得到
*** TypeError: No matching definition for argument type(s) array(float64, 2d, F), array(float64, 2d, C), array(float64, 2d, F),
instance.jitclass.EarthModel#7fd9c48dd668
在我看来,定义不匹配的参数类型与我上面的类型几乎相同。另一方面,如果我不通过仅使用 @numba.jit(nopython=True)
指定签名,它工作正常并且由 numba 编译的函数的签名是
ipdb> numbed_cowell_propagator_propagate.signatures
[(array(float64, 2d, F), array(float64, 2d, C), array(float64, 2d, F),
instance.jitclass.EarthModel#7f81bbc0e780)]
编辑
如果我使用 FAQ 中的方式强制执行 C 序数组,我仍然会收到错误
TypeError: No matching definition for argument type(s) array(float64,
2d, C), array(float64, 2d, C), array(float64, 2d, C),
instance.jitclass.EarthModel#7f6edd8d57b8
我很确定问题与延迟类型有关,因为如果不是传递 jit class,我会传递我需要的所有属性 class (4 numba.float64
s),效果很好。
我在指定签名时做错了什么?
干杯。
在不完全了解您的完整代码如何工作的情况下,我不确定您为什么需要使用延迟类型。通常它用于包含相同类型实例变量的 jitclasses,例如链表或其他节点树,因此需要推迟到编译器处理 class 本身(参见 source)下面的最小示例有效(如果我使用延迟类型,我可以重现你的错误):
import numpy as np
import numba as nb
spec = [('x', nb.float64)]
@nb.jitclass(spec)
class EarthModel:
def __init__(self, x):
self.x = x
earth_model_type = EarthModel.class_type.instance_type
@nb.jit(nb.float64(nb.float64[:, :], nb.float64[:, :], nb.float64[:, :], earth_model_type))
def test(x, y, z, em):
return em.x
然后运行它:
em = EarthModel(9.9)
x = np.random.normal(size=(3,3))
y = np.random.normal(size=(3,3))
z = np.random.normal(size=(3,3))
res = test(x, y, z, em)
print(res) # 9.9
我将 numba 作为函数的签名传递
@numba.jit(numba.types.UniTuple(numba.float64[:, :], 2)(
numba.float64[:, :], numba.float64[:, :], numba.float64[:, :],
earth_model_type))
其中 earth_model_type
定义为
earth_model_type = numba.deferred_type()
earth_model_type.define(em.EarthModel.class_type.instance_type)
它编译得很好,但是当我尝试调用函数时我得到
*** TypeError: No matching definition for argument type(s) array(float64, 2d, F), array(float64, 2d, C), array(float64, 2d, F), instance.jitclass.EarthModel#7fd9c48dd668
在我看来,定义不匹配的参数类型与我上面的类型几乎相同。另一方面,如果我不通过仅使用 @numba.jit(nopython=True)
指定签名,它工作正常并且由 numba 编译的函数的签名是
ipdb> numbed_cowell_propagator_propagate.signatures
[(array(float64, 2d, F), array(float64, 2d, C), array(float64, 2d, F), instance.jitclass.EarthModel#7f81bbc0e780)]
编辑
如果我使用 FAQ 中的方式强制执行 C 序数组,我仍然会收到错误
TypeError: No matching definition for argument type(s) array(float64, 2d, C), array(float64, 2d, C), array(float64, 2d, C), instance.jitclass.EarthModel#7f6edd8d57b8
我很确定问题与延迟类型有关,因为如果不是传递 jit class,我会传递我需要的所有属性 class (4 numba.float64
s),效果很好。
我在指定签名时做错了什么?
干杯。
在不完全了解您的完整代码如何工作的情况下,我不确定您为什么需要使用延迟类型。通常它用于包含相同类型实例变量的 jitclasses,例如链表或其他节点树,因此需要推迟到编译器处理 class 本身(参见 source)下面的最小示例有效(如果我使用延迟类型,我可以重现你的错误):
import numpy as np
import numba as nb
spec = [('x', nb.float64)]
@nb.jitclass(spec)
class EarthModel:
def __init__(self, x):
self.x = x
earth_model_type = EarthModel.class_type.instance_type
@nb.jit(nb.float64(nb.float64[:, :], nb.float64[:, :], nb.float64[:, :], earth_model_type))
def test(x, y, z, em):
return em.x
然后运行它:
em = EarthModel(9.9)
x = np.random.normal(size=(3,3))
y = np.random.normal(size=(3,3))
z = np.random.normal(size=(3,3))
res = test(x, y, z, em)
print(res) # 9.9