使用 numba 加速 odeint,尝试传递字典和自定义对象时出现问题

Using numba to accelerate odeint, issues when trying to pass a dictionary and custom object

我正在做一个个人项目,在 Python 中对四轴飞行器模拟(和控制)进行编码,作为一个学习项目。我正在使用 scipy 积分器 odeint 并且我对长计算时间感到非常失望。所以我希望使用 numba 来加速我的集成。我在每个时间步都调用 odeint,因为我必须在每个模拟时间步之后创建命令。

起初,当我要集成的函数 (state_dot) 是 Quadcopter class 的方法时,我遇到了问题。所以我把它作为一个单独的函数,但是当我用 @jit 装饰我的函数时,我现在在定义正确的类型时遇到了问题。 state_dot 函数有一个字典 (params) 作为输入参数(我读过 numba 支持字典),但也是一个自定义的 class (wind),因为我的风模型就是那个class的一个方法。如果我暂时排除 wind,使用 numba.typed.Dict 似乎无法导入字典。

要在函数中导入 wind 对象,我看到使用了 numba 类型 object_,但是 Python 在中找不到 object_ numba.

我使用的是 numba 0.45.0 版和 Python 3.7.

import numpy as np
from scipy.integrate import odeint
from numba import jit, void, float_, int_
import numba

class Quadcopter:

    def __init__(self):

        # Quad Params
        # ---------------------------
        mB  = 1.2       # mass (kg)
        params = {}
        params["mB"]   = mB
        self.params = params


        # Initial State
        # ---------------------------
        self.state = np.zeros(3)

    def update(self, t, Ts, cmd, wind):

        self.state = odeint(state_dot, self.state, [t,t+Ts], args = (cmd, self.params, wind))[1]


@jit(void(float_[:], float_, float_[:], numba.typed.Dict )) #(nopython = True)
def state_dot(state, t, cmd, params, wind):

    # Import Params
    # ---------------------------    
    mB   = params["mB"]

    # Import State Vector
    # ---------------------------  
    x      = state[0]
    y      = state[1]
    z      = state[2]

    # Motor Dynamics and Rotor forces (Second Order System: https://apmonitor.com/pdc/index.php/Main/SecondOrderSystems)
    # ---------------------------
    print(cmd)

    # Wind Model
    # ---------------------------
    [velW, qW1, qW2] = wind.randomWind(t)
    print(velW)

    # State Derivative Vector
    # ---------------------------
    sdot     = np.zeros(3)
    sdot[0]  = x*t + 0.1
    sdot[1]  = y*t + 0.1
    sdot[2]  = z*t + 0.1


    return sdot


class Wind:

    def __init__(self):

        # Normally, average wind would be randomly set here
        self.velW_med = 5.0
        self.qW1_med  = 0.2
        self.qW2_med  = 0.1

    def randomWind(self, t):

        # Normally, wind values would be a sine function dependant of current time
        velW = self.velW_med
        qW1  = self.qW1_med
        qW2  = self.qW2_med

        return velW, qW1, qW2

# Set time
Ti = 0
Ts = 0.005
Tf = 10

# Initialize quadcopter and wind
quad = Quadcopter()
wind = Wind()

# Simulation
t = Ti
while round(t,3) < Tf:
    cmd = np.array([1,2,1,3])
    quad.update(t, Ts, cmd, wind)
    print(quad.state)
    t += Ts

收到的错误是

Traceback (most recent call last):
  File "c:/Users/JOHN-Laptop/Documents/Code Dev/Test/question_quad.py", line 29, in <module>
    @jit(void(float_[:], float_, float_[:], numba.typed.Dict )) #(nopython = True)
  File "C:\Users\JOHN-Laptop\AppData\Local\Programs\Python\Python37\lib\site-packages\numba\decorators.py", line 186, in wrapper
    disp.compile(sig)
  File "C:\Users\JOHN-Laptop\AppData\Local\Programs\Python\Python37\lib\site-packages\numba\compiler_lock.py", line 32, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "C:\Users\JOHN-Laptop\AppData\Local\Programs\Python\Python37\lib\site-packages\numba\dispatcher.py", line 676, in compile
    args, return_type = sigutils.normalize_signature(sig)
  File "C:\Users\JOHN-Laptop\AppData\Local\Programs\Python\Python37\lib\site-packages\numba\sigutils.py", line 48, in normalize_signature
    check_type(ty)
  File "C:\Users\JOHN-Laptop\AppData\Local\Programs\Python\Python37\lib\site-packages\numba\sigutils.py", line 43, in check_type
    "instance, got %r" % (ty,))
TypeError: invalid type in signature: expected a type instance, got <class 'numba.typed.typeddict.Dict'>

完整代码可以在这里查看:https://github.com/bobzwik/Quadcopter_SimCon/blob/dev_numba/Simulation/quadFiles/quad.py

如果我遗漏任何信息,请随时询问。

编辑:将完整代码的 link 更改为 link 到另一个分支。

我注意到的第一件事是——至少在你在这里展示的代码中——你的 jit 签名有四种类型,但你正在装饰的函数有五个参数:

@jit(void(float_[:], float_, float_[:], numba.typed.Dict))
def state_dot(state, t, cmd, params, wind):

很明显你需要解决这个问题。最简单的尝试就是删除签名并让 numba 解决:

@jit
def state_dot(state, t, cmd, params, wind):

当然,即使您这样做,numba 仍然抱怨它不知道如何键入所有内容,并指向 mB = params["mB"] 行。它仍然可以 "loop lifting",这意味着它可以编译一些东西,但不会尽可能快。

所以第二件要注意的事情是,虽然 numba 说它支持 dicts,但随后又提出了很多警告。基本上,使用字典仍然不是一个好主意。我也没有看到你使用 dict 的任何充分理由。为什么不像 self.mB = mB 那样让 mB 成为您的 class 的成员?我知道你的完整 Quadcopter class 中会有更复杂的东西,但你可以有很多成员。

现在,要注意的第三件事是,自从我编写 that gist you pointed out elsewhere, and can now handle classes, so you might want to look into numba.jitclass 以来,numba 已经变得更好了。通常,当您将 jitclass 对象传递给您尝试 jit 的函数时,numba 将知道如何处理它。

但也许比所有这些更重要的是您的 update 方法会为每个步骤调用 odeint。我猜这是你代码中最慢的部分。该函数应该被调用一次,以便它可以从头到尾解决您的整个问题,因此它有很多(相对较慢的)开销与理解您传递的参数、分配内存、初始化事物等相关. 一个更好的方法是构造一个 scipy.integrate.ode object to keep everything set up between steps — and keep it around so that you can use the same one between steps. The newer interfaces solve_ivp and RK45 (和类似的)分别大致相当于 odeintode,除了 ode 有我喜欢的求解器 dop853。如果您只需要 OdeSolver subclass 之一,您可能更喜欢这些接口。另请注意,如果您实际上在 步骤之间更改了状态中的任何内容,您可能需要再次调用 set_initial_value,否则可能会在您不注意的情况下出错。

更一般地说,如果您担心速度,您可以做的最好的事情就是分析您的代码。这里的第一步是在 ipython.

中使用 %prun