在 python numba jitclass 中调用 njit 函数失败
calling njit function in python numba jitclass fails
@njit
def cumutrapz(x:np.array, y:np.array):
return np.append(0, [
np.trapz(y=y[i-2:i], x=x[i-2:i]) for i in range(2, len(x) + 1)]).cumsum()
from numba import float64
@jitclass([
('a', float64[:]),
('b', float64[:]),
('c', float64[:]),
])
class Testaroo(object):
def __init__(self, a, b):
self.a = a
self.b = b
self.c = np.zeros(len(self.a), dtype=np.float64)
def set_c(self):
self.c = cumutrapz(self.a, self.b)
testaroo = Testaroo(
np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()
以上失败,但以下两个非常相似的示例有效:
cumutrapz(np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
和
from numba import float64
@jitclass([
('a', float64[:]),
('b', float64[:]),
('c', float64[:]),
])
class Testaroo(object):
def __init__(self, a, b):
self.a = a
self.b = b
self.c = np.zeros(len(self.a), dtype=np.float64)
def set_c(self):
self.c = (self.a * self.b).cumsum()
testaroo = Testaroo(
np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()
后一个示例现在对我有用,但我想知道是否有办法让 cumutrapz
函数在 jitclass
.
中运行
我正在使用 numba 版本“0.53.1”。
仔细阅读长长的错误信息,您会发现:
No implementation of function Function(<function trapz at 0x7f7e9b21e5e0>)
found for signature:
>>> trapz(y=array(float64, 1d, A), x=array(float64, 1d, A))
...
reshape() supports contiguous array only
具有 format A (any) 的数组不一定是连续的。
您可以确保该函数仅处理 contiguous arrays:
@njit([nb.float64[::1](nb.float64[::1], nb.float64[::1])])
def cumutrapz(x, y):
...
然后出现新的错误:
Invalid use of type(CPUDispatcher(<function MyTestCase.test_cumutrapz.<locals>.cumutrapz at 0x7f8e15a841f0>))
with parameters (array(float64, 1d, A), array(float64, 1d, A))
Known signatures:
* (array(float64, 1d, C), array(float64, 1d, C)) -> array(float64, 1d, C)
...
self.c = cumutrapz(self.a, self.b)
^
因此 class 中的数组不连续。
为了确保它们是,您可以将 class 规范更改为:
@jitclass([
('a', nb.float64[::1]),
('b', nb.float64[::1]),
('c', nb.float64[::1]),
])
现在可以使用了(使用 Numba 0.54.0 测试)。
@njit
def cumutrapz(x:np.array, y:np.array):
return np.append(0, [
np.trapz(y=y[i-2:i], x=x[i-2:i]) for i in range(2, len(x) + 1)]).cumsum()
from numba import float64
@jitclass([
('a', float64[:]),
('b', float64[:]),
('c', float64[:]),
])
class Testaroo(object):
def __init__(self, a, b):
self.a = a
self.b = b
self.c = np.zeros(len(self.a), dtype=np.float64)
def set_c(self):
self.c = cumutrapz(self.a, self.b)
testaroo = Testaroo(
np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()
以上失败,但以下两个非常相似的示例有效:
cumutrapz(np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
和
from numba import float64
@jitclass([
('a', float64[:]),
('b', float64[:]),
('c', float64[:]),
])
class Testaroo(object):
def __init__(self, a, b):
self.a = a
self.b = b
self.c = np.zeros(len(self.a), dtype=np.float64)
def set_c(self):
self.c = (self.a * self.b).cumsum()
testaroo = Testaroo(
np.arange(50, dtype=np.float64), np.sin(np.arange(50, dtype=np.float64)))
testaroo.set_c()
后一个示例现在对我有用,但我想知道是否有办法让 cumutrapz
函数在 jitclass
.
我正在使用 numba 版本“0.53.1”。
仔细阅读长长的错误信息,您会发现:
No implementation of function Function(<function trapz at 0x7f7e9b21e5e0>)
found for signature:
>>> trapz(y=array(float64, 1d, A), x=array(float64, 1d, A))
...
reshape() supports contiguous array only
具有 format A (any) 的数组不一定是连续的。
您可以确保该函数仅处理 contiguous arrays:
@njit([nb.float64[::1](nb.float64[::1], nb.float64[::1])])
def cumutrapz(x, y):
...
然后出现新的错误:
Invalid use of type(CPUDispatcher(<function MyTestCase.test_cumutrapz.<locals>.cumutrapz at 0x7f8e15a841f0>))
with parameters (array(float64, 1d, A), array(float64, 1d, A))
Known signatures:
* (array(float64, 1d, C), array(float64, 1d, C)) -> array(float64, 1d, C)
...
self.c = cumutrapz(self.a, self.b)
^
因此 class 中的数组不连续。
为了确保它们是,您可以将 class 规范更改为:
@jitclass([
('a', nb.float64[::1]),
('b', nb.float64[::1]),
('c', nb.float64[::1]),
])
现在可以使用了(使用 Numba 0.54.0 测试)。