Calling Cython extension types from Numba jitted class
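How would I call methods of a Cython extension type from within a Numba jitted class? My minimal example below fails with the error recorded further down. How would I modify the minimal example so that it works?
Thanks for your help!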
Minimal example
I have a Cython module, shrubbery.pyx:
cdef class Shrubbery:
    cdef int height

    def __init__(self, h):
        self.height = h

    def describe(self):
        print('This shrubbery is', self.height, 'tall.')
I have a setup file, setup.py:
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

ext_modules = [Extension('shrubbery', ['shrubbery.pyx'])]

setup(
    name='shrubbery',
    cmdclass={'build_ext': build_ext},
    ext_modules=ext_modules)
I compile shrubbery.pyx into an extension module as usual (python setup.py build_ext --inplace). Then I try to use Shrubbery inside a Numba jitted class as follows:
from shrubbery import Shrubbery
import numba as nb

spec = [('value', nb.int32)]

@nb.jitclass(spec)
class Bag(object):
    def __init__(self, value):
        self.value = value

    def size(self):
        return self.value

    def mixed_class_method(self):
        __shrubbery = Shrubbery(5)
        __shrubbery.describe()

# pure numba class: works
_b = Bag(value=3)
print(_b.size())

# pure cython extension type: works
__shrubbery = Shrubbery(5)
__shrubbery.describe()

# mix of cython extension type and numba jitted class: fails
_b.mixed_class_method()
Error
/Users/mg/anaconda/bin/python3 test.py
3
('This shrubbery is', 5, 'tall.')
Traceback (most recent call last):
File "test.py", line 28, in <module>
_b.mixed_class_method()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/jitclass/boxing.py", line 62, in wrapper
return method(*args, **kwargs)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
raise e
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
return self.compile(tuple(argtypes))
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 579, in compile
cres = self._compiler.compile(args, return_type)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 80, in compile
flags=flags, locals=self.locals)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 779, in compile_extra
return pipeline.compile_extra(func)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 362, in compile_extra
return self._compile_bytecode()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 738, in _compile_bytecode
return self._compile_core()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 725, in _compile_core
res = pm.run(self.status)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 248, in run
raise patched_exception
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 240, in run
stage()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 454, in stage_nopython_frontend
self.locals)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 881, in type_inference_stage
infer.propagate()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 846, in propagate
raise errors[0]
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 137, in propagate
constraint(typeinfer)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 415, in __call__
self.resolve(typeinfer, typevars, fnty)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 441, in resolve
sig = typeinfer.resolve_call(fnty, pos_args, kw_args, literals=literals)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 1115, in resolve_call
literals=literals)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typing/context.py", line 204, in resolve_function_type
return func.get_call_type_with_literals(self, args, kws, literals)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/types/functions.py", line 199, in get_call_type_with_literals
return self.get_call_type(context, args, kws)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/types/functions.py", line 193, in get_call_type
return self.template(context).apply(args, kws)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typing/templates.py", line 207, in apply
sig = generic(args, kws)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/jitclass/base.py", line 322, in generic
sig = disp_type.get_call_type(self.context, args, kws)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/types/functions.py", line 250, in get_call_type
template, pysig, args, kws = self.dispatcher.get_call_template(args, kws)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 269, in get_call_template
self.compile(tuple(args))
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 579, in compile
cres = self._compiler.compile(args, return_type)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/dispatcher.py", line 80, in compile
flags=flags, locals=self.locals)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 779, in compile_extra
return pipeline.compile_extra(func)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 362, in compile_extra
return self._compile_bytecode()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 738, in _compile_bytecode
return self._compile_core()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 725, in _compile_core
res = pm.run(self.status)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 248, in run
raise patched_exception
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 240, in run
stage()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 454, in stage_nopython_frontend
self.locals)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/compiler.py", line 880, in type_inference_stage
infer.build_constraint()
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 802, in build_constraint
self.constrain_statement(inst)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 961, in constrain_statement
self.typeof_assign(inst)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 1023, in typeof_assign
self.typeof_global(inst, inst.target, value)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 1119, in typeof_global
typ = self.resolve_value_type(inst, gvar.value)
File "/Users/mg/anaconda/lib/python3.6/site-packages/numba/typeinfer.py", line 1042, in resolve_value_type
raise TypingError(msg, loc=inst.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Failed at nopython (nopython frontend)
Untyped global name 'Shrubbery': cannot determine Numba type of <class 'type'>
File "test.py", line 16
[1] During: resolving callee type: BoundFunction((<class 'numba.types.misc.ClassInstanceType'>, 'mixed_class_method') for instance.jitclass.Bag#7fef29835df8<value:int32>)
[2] During: typing of call at <string> (3)
From the Numba docs:
"All methods of a jitclass is compiled into nopython functions. The data of a jitclass instance is allocated on the heap as a C-compatible structure so that any compiled functions can have direct access to the underlying data, bypassing the interpreter."
As DavidW pointed out, Shrubbery is a Python type rather than a C type, so you can't use it inside a jitclass.
You can, however, jit the individual methods.
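One reading of "jit the individual methods" is: keep Bag as an ordinary Python class, so it can freely hold and call Python-level types such as Shrubbery, and move only the numeric work into module-level functions compiled with Numba. The sketch below assumes that pattern; FakeShrubbery and _triangular are hypothetical names, FakeShrubbery stands in for the compiled Cython type (its describe returns a string instead of printing), and the try/except only exists so the sketch still runs where Numba is not installed.

```python
# Sketch: Bag stays a plain Python class; only the numeric helper is jitted.
try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch runs without Numba
    def njit(func):
        return func


class FakeShrubbery:
    """Hypothetical stand-in for the compiled Cython Shrubbery extension type."""

    def __init__(self, h):
        self.height = h

    def describe(self):
        return f"This shrubbery is {self.height} tall."


@njit
def _triangular(value):
    # Hypothetical numeric kernel: int in, int out, so nopython mode can compile it
    total = 0
    for i in range(value):
        total += i
    return total


class Bag:
    """Ordinary Python class, so it may use Python/Cython objects directly."""

    def __init__(self, value):
        self.value = value

    def size(self):
        # Only this call runs compiled code; self itself never enters Numba
        return _triangular(self.value)

    def mixed_class_method(self):
        # Runs in the interpreter, so a Python-level type is fine here
        return FakeShrubbery(5).describe()


bag = Bag(3)
print(bag.size())                # 0 + 1 + 2
print(bag.mixed_class_method())
```

The trade-off is that Bag loses the C-struct layout a jitclass gives you, so jitted code cannot receive a Bag instance; you pass its plain fields (here self.value) into the compiled helpers instead.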
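This is mostly a response to your suggestion in the comments that CFFI functions could be made to work. That is true, but very limited.
You can convert a Cython cdef function into a CFFI function by way of a C function pointer. The conversion has to happen in Cython. To be usable from Numba in nopython mode, the cdef function must not accept or return a Python object. That rules out your Shrubbery class. A simple function that only accepts and returns C types does work: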
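# cy.pyx
from libc.stdint cimport uintptr_t

# noexcept (required by Cython 3, accepted by Cython >= 0.29.31) keeps f a plain
# C void(int) so it can be assigned to the function pointer type below
cdef void f(int x) noexcept nogil:
    with gil:
        print(x + 1)

ctypedef void (*void_int_func_pointer)(int) noexcept

def get_cffi_f():
    cdef void_int_func_pointer f_ptr = f
    # Reinterpret the function pointer as an integer Python can pass around
    cdef uintptr_t f_ptr_int = <uintptr_t>f_ptr
    from cffi import FFI
    ffi = FFI()
    # Have CFFI treat that integer as a 'void (*)(int)' function pointer
    return ffi.cast('void (*)(int)', f_ptr_int)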
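In Python you call get_cffi_f() to get a CFFI wrapper of f that you can pass to a Numba function. Note that I've declared the function nogil and acquire the GIL inside it. I'm not 100% sure whether Numba releases the GIL, so I did this to be on the safe side; it may not be necessary.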
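The trick in get_cffi_f is a round trip through an integer: take the function's address as a uintptr_t, then have CFFI reinterpret that integer as a void (*)(int). As a minimal sketch of that same round trip that runs without compiling anything, the standard library's ctypes can stand in for both Cython and CFFI. This swap is an illustration only, not the answer's method: the C function pointer here wraps a Python callback rather than a Cython cdef function, and the names _record and received are hypothetical.

```python
import ctypes

# C signature void (*)(int), the same shape as void_int_func_pointer
VOID_INT_FUNC = ctypes.CFUNCTYPE(None, ctypes.c_int)

received = []


def _record(x):
    # Hypothetical stand-in for the body of the cdef function f
    received.append(x + 1)


# Wrap the Python function as a real C function pointer; keep this object alive
c_callback = VOID_INT_FUNC(_record)

# Analogue of <uintptr_t>f_ptr: the raw address as a plain Python int
address = ctypes.cast(c_callback, ctypes.c_void_p).value

# Analogue of ffi.cast('void (*)(int)', f_ptr_int): rebuild a callable from the int
rebuilt = VOID_INT_FUNC(address)
rebuilt(5)

print(received)  # [6]
```

As with the CFFI version, nothing checks that the integer really points at a void(int) function; passing a wrong address or signature crashes the process rather than raising an exception.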
You can then pass these CFFI wrappers into Numba functions or access them as globals:
import numba as nb
from numba.experimental import jitclass  # plain nb.jitclass in older Numba releases
from cy import get_cffi_f

func_global = get_cffi_f()

@nb.jit(nopython=True)
def simple_func(func):
    func(5)
    func_global(6)
    func(7)

@jitclass([('value', nb.int32)])
class Bag(object):
    def __init__(self, value):
        self.value = value

    def mixed_class_method(self, func):
        func(self.value)
        func_global(self.value - 1)

simple_func(get_cffi_f())
Bag(3).mixed_class_method(get_cffi_f())
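My view is that trying to get something like a Cython class to work here is a lost cause.
There may be other ways to achieve the same thing: you could have Cython generate headers using api or public declarations and use those headers with CFFI.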