当 python class jitclass 本身包含 jitclass classes 时,如何使其兼容?

How make a python class jitclass compatible when it contains itself jitclass classes?

我正在尝试制作一个 class,它可能是 jitclass 的一部分,但具有一些自身 jitclass 对象的属性。

例如,如果我有两个 class 和装饰器 @jitclass,我想在第三个 class (combined) 中实例化它们。

import numpy as np
from numba import jitclass
from numba import boolean, int32, float64,uint8

spec = [
    ('type' ,int32),
    ('val' ,float64[:]),
    ('result',float64)]

@jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)

@jitclass(spec)
class Second:
    def __init__(self):
        self.type = 2
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)



@jitclass(spec)
class Combined:
    def __init__(self):
        self.List = []
        for i in range(10):
            self.List.append(First())
            self.List.append(Second())

    def sum(self):
        for i, c in enumerate(self.List):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.List):
            result.append(c.result)
        return result


C = Combined()
C.sum()
result = C.getresult()
print(result)

在那个例子中我得到一个错误,因为 numba 无法确定 self.List 的类型,它是两个 jitclasses.

的组合

如何使 class Combinedjitclass 兼容?

更新

它尝试了我在别处找到的东西:

import numpy as np
from numba import jitclass, deferred_type
from numba import boolean, int32, float64,uint8
from numba.typed import List

spec = [
    ('type' ,int32),
    ('val' ,float64[:]),
    ('result',float64)]

@jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)
 

 
spec1 = [('ListA',  List(First.class_type.instance_type, reflected=True))]

@jitclass(spec1)
class Combined:
    def __init__(self):
        self.ListA = [First(),First()] 

    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result


C = Combined()
C.sum()
result = C.getresult()
print(result)

但是我得到这个错误

List(First.class_type.instance_type)
TypeError: __init__() takes 1 positional argument but 2 were given

长话短说:

  • 您可以在 jitclass 中引用其他 jitclass 内容,即使您有这些内容的列表。您只需要更正命名空间 numba.typed -> numba.types
  • 目前(从 numba 0.46 开始)在 jitclasses 或 no-python numba.jit 函数中不可能有异构列表。所以你不能在同一个列表中附加 FirstSecond 的两个实例。

解决 numba.typed.List 异常

您的更新几乎是正确的。您需要使用 numba.types.List 而不是 numba.typed.List。区别有点微妙,但是 numba.types 包含签名类型,而 numba.typed 命名空间包含 类 可以在代码中实例化和使用。

所以如果你使用它会起作用:

spec1 = [('ListA',  nb.types.List(First.class_type.instance_type, reflected=True))]

更改此代码:

import numpy as np
import numba as nb

spec = [
    ('type', nb.int32),
    ('val', nb.float64[:]),
    ('result', nb.float64)
]

@nb.jitclass(spec)
class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)

spec1 = [('ListA',  nb.types.List(First.class_type.instance_type, reflected=True))]

@nb.jitclass(spec1)
class Combined:
    def __init__(self):
        self.ListA = [First(), First()] 
    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result

C = Combined()
C.sum()
result = C.getresult()
print(result)

产生输出:[100.0, 100.0].

间奏曲:在这里使用jitclass有意义吗?

然而,这里要记住的是,正常的 Python 类 可能比 jitclass-方法(或同样快)更快:

import numpy as np
import numba as nb

class First:
    def __init__(self):
        self.type = 1
        self.val = np.ones(100)
        self.result = 0.
    def sum(self):
        self.result = np.sum(self.val)

class Combined:
    def __init__(self):
        self.ListA = [First(), First()] 
    def sum(self):
        for i, c in enumerate(self.ListA):
            c.sum()
    def getresult(self):
        result = []
        for i, c in enumerate(self.ListA):
            result.append(c.result)
        return result

C = Combined()
C.sum()
C.getresult()

如果这只是出于好奇,那没问题。但是对于生产,我会从纯 Python+NumPy 开始,只在 numba 太慢时才应用它,然后只在瓶颈部分应用,并且只有 numba 擅长优化这些东西(numba 是专门的工具moment,不是通用工具)。

具有 numba 的异构(混合类型)列表?

在 no-python(无对象)模式下使用 numba,您需要同类列表。据我所知,numba 0.46 不支持在 jit类 或 nopython-jit 方法中包含不同类型对象的列表。这意味着您不能有一个包含 FirstSecond 个实例的列表。

所以这行不通:

self.List.append(First())
self.List.append(Second())

来自numba docs

Creating and returning lists from JIT-compiled functions is supported, as well as all methods and operations. Lists must be strictly homogeneous: Numba will reject any list containing objects of different types, even if the types are compatible [...]