如何在仅包含 Numpy 数组的程序中使用 numba 的 jit?

How can I use numba's jit in my program which contains only Numpy arrays?

我的程序计算求解线性微分方程时的误差。它仅使用 numpy 数组。当我尝试对我定义的函数使用 numba 的 jit 装饰器时,我只是收到错误。能帮我正确使用吗?

我的代码:

import numpy as np
from numba import jit

def rk4(t_prev, x_prev, derivs, dt):
    k1 = dt * derivs(t_prev, x_prev)
    k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
    k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
    k4 = dt * derivs(t_prev + dt, x_prev + k3)
    x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
    return x_next

global k, x_0, v_0, t_0, t_f

k = 1

x_0 = 0
v_0 = np.sqrt(k)

t_0 = 0
t_f = 10

dtList = np.logspace(0, -5, 1000)


def derivs(t, X):
    deriv = np.zeros([2])
    deriv[0] = X[1]
    deriv[1] = -k * X[0]
    return deriv


def err(dt):
    tList = np.arange(t_0, t_f + dt, dt)
    N = tList.shape[0]
    XList = np.zeros([N,2])
    XList[0][0], XList[0][1] = x_0, v_0
    for i in range(N-1):
        XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
    error = np.abs(XList[-1][0] - np.sin(10))
    return error

print(err(.001))

以下对我有用:

import numpy as np
from numba import jit

@jit(nopython=True)
def rk4(t_prev, x_prev, derivs, dt):
    k1 = dt * derivs(t_prev, x_prev)
    k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
    k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
    k4 = dt * derivs(t_prev + dt, x_prev + k3)
    x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
    return x_next

global k, x_0, v_0, t_0, t_f

k = 1

x_0 = 0
v_0 = np.sqrt(k)

t_0 = 0
t_f = 10

dtList = np.logspace(0, -5, 1000)

@jit(nopython=True)
def derivs(t, X):
    deriv = np.zeros(2)
    deriv[0] = X[1]
    deriv[1] = -k * X[0]
    return deriv


@jit(nopython=True)
def err(dt):
    tList = np.arange(t_0, t_f + dt, dt)
    N = tList.shape[0]
    XList = np.zeros((N,2))
    XList[0][0], XList[0][1] = x_0, v_0
    for i in range(N-1):
        XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
    error = np.abs(XList[-1][0] - np.sin(10))
    return error

print(err(.001))

请注意,我对您的代码所做的唯一两处更改是将对 np.zeros 的调用替换为在 2d 情况下将列表传递给 tuple 或仅在第一种情况。请参阅以下问题以了解为什么会这样:

https://github.com/numba/numba/issues/3993