在 numba jit 函数中使用 numba jitclass 作为参数
Use numba jitclass as a parameter in numba jit function
我正在使用 numba 0.46.0,我想将 class 的一个对象作为参数传递给我的函数,并 运行 在我的 GPU 上使用 CUDA 传递这个函数。
如果我想使用一个简单的 Python 对象(比如 int
),我会使用这样的东西:
from numba import jit, cuda
from numba.types import void, int32
@jit(void(int32), target='cuda')
def f(int_object):
pass
f(123)
这很好用。现在我尝试用 class:
做同样的事情
from numba import jit, cuda
from numba,types import void
@jitclass([])
class MyClass:
def __init__(self):
pass
@jit(void(MyClass), target='cuda')
def f(MyClass_object):
pass
这失败了 NotImplementedError
,没有任何评论。我也尝试以懒惰的方式编译它:
@jit(target='cuda')
def f(MyClass_object):
pass
f(MyClass())
这失败了
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/dispatcher.py", line 42, in __call__
return self.compiled(*args, **kws)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 801, in __call__
cfg(*args)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 537, in __call__
sharedmem=self.sharedmem)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 604, in _kernel_call
self._prepare_args(t, v, stream, retr, kernelargs)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 715, in _prepare_args
raise NotImplementedError(ty, val)
NotImplementedError: (instance.jitclass.MyClass#7f983418fc88<>, <numba.jitclass.boxing.MyClass object at 0x7f983416ca10>)
我可以使用 jitclass 对象作为 jit 函数参数吗?如果是,上面的例子有什么问题?
更新:
顺便说一下,我已经用 numpy 数组检查过这个:
import numpy as np
from numba import jit, cuda
from numba.types import void
@jit(void(np.ndarray), target='cuda')
def f1(ndarray_object):
pass
# Fails with NotImplementedError with no comments
@jit(target='cuda')
def f2(ndarray_object):
pass
a = np.asarray([])
f2(a) # Executes with no errors, only a warning about autojit
为什么这适用于 numpy,但不适用于我的 class?为什么这适用于惰性模式 (f2) 中的 numpy,但不适用于给定的签名 (f1)?
根据相关 documentation(撰写本文时 Numba 0.47):
Support for jitclasses are available on CPU only. (Note: Support for
GPU devices is planned for a future release.)
我正在使用 numba 0.46.0,我想将 class 的一个对象作为参数传递给我的函数,并 运行 在我的 GPU 上使用 CUDA 传递这个函数。
如果我想使用一个简单的 Python 对象(比如 int
),我会使用这样的东西:
from numba import jit, cuda
from numba.types import void, int32
@jit(void(int32), target='cuda')
def f(int_object):
pass
f(123)
这很好用。现在我尝试用 class:
做同样的事情from numba import jit, cuda
from numba,types import void
@jitclass([])
class MyClass:
def __init__(self):
pass
@jit(void(MyClass), target='cuda')
def f(MyClass_object):
pass
这失败了 NotImplementedError
,没有任何评论。我也尝试以懒惰的方式编译它:
@jit(target='cuda')
def f(MyClass_object):
pass
f(MyClass())
这失败了
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/dispatcher.py", line 42, in __call__
return self.compiled(*args, **kws)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 801, in __call__
cfg(*args)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 537, in __call__
sharedmem=self.sharedmem)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 604, in _kernel_call
self._prepare_args(t, v, stream, retr, kernelargs)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 715, in _prepare_args
raise NotImplementedError(ty, val)
NotImplementedError: (instance.jitclass.MyClass#7f983418fc88<>, <numba.jitclass.boxing.MyClass object at 0x7f983416ca10>)
我可以使用 jitclass 对象作为 jit 函数参数吗?如果是,上面的例子有什么问题?
更新: 顺便说一下,我已经用 numpy 数组检查过这个:
import numpy as np
from numba import jit, cuda
from numba.types import void
@jit(void(np.ndarray), target='cuda')
def f1(ndarray_object):
pass
# Fails with NotImplementedError with no comments
@jit(target='cuda')
def f2(ndarray_object):
pass
a = np.asarray([])
f2(a) # Executes with no errors, only a warning about autojit
为什么这适用于 numpy,但不适用于我的 class?为什么这适用于惰性模式 (f2) 中的 numpy,但不适用于给定的签名 (f1)?
根据相关 documentation(撰写本文时 Numba 0.47):
Support for jitclasses are available on CPU only. (Note: Support for GPU devices is planned for a future release.)