使用 Numba 加速 Python 代码时引发 TypeError

TypeError raised when using Numba to accelerate Python code

我正在使用 Numba 的 @jit 装饰器通过 nopython=True 选项优化我的函数。当我删除这个选项时,我的代码运行良好,但我读到启用该选项将显着加快代码速度。

但是,我收到以下回溯和错误:

Traceback (most recent call last):
  File "C:\Users\dis_YO_boi\Documents\Programming\Python\Base3DSolver4.py", line 83, in <module>
    soln = odeint(f, y0, time, mxstep = 5000)
  File "C:\Anaconda3\lib\site-packages\scipy\integrate\odepack.py", line 153, in odeint
    ixpr, mxstep, mxhnil, mxordn, mxords)
  File "C:\Anaconda3\lib\site-packages\numba\dispatcher.py", line 172, in _compile_for_args
    return self.compile(sig)
  File "C:\Anaconda3\lib\site-packages\numba\dispatcher.py", line 350, in compile
    flags=flags, locals=self.locals)
  File "C:\Anaconda3\lib\site-packages\numba\compiler.py", line 644, in compile_extra
    return pipeline.compile_extra(func)
  File "C:\Anaconda3\lib\site-packages\numba\compiler.py", line 361, in compile_extra
    return self.compile_bytecode(bc, func_attr=self.func_attr)
  File "C:\Anaconda3\lib\site-packages\numba\compiler.py", line 370, in compile_bytecode
    return self._compile_bytecode()
  File "C:\Anaconda3\lib\site-packages\numba\compiler.py", line 631, in _compile_bytecode
    return pm.run(self.status)
  File "C:\Anaconda3\lib\site-packages\numba\compiler.py", line 251, in run
    raise patched_exception
  File "C:\Anaconda3\lib\site-packages\numba\compiler.py", line 243, in run
    res = stage()
  File "C:\Anaconda3\lib\site-packages\numba\compiler.py", line 458, in stage_nopython_frontend
    self.locals)
  File "C:\Anaconda3\lib\site-packages\numba\compiler.py", line 759, in type_inference_stage
    infer.propagate()
  File "C:\Anaconda3\lib\site-packages\numba\typeinfer.py", line 510, in propagate
    raise errors[0]
numba.errors.TypingError: Failed at nopython (nopython frontend)
Internal error at <numba.typeinfer.SetItemConstraint object at 0x0000000007C082B0>:
--%<-----------------------------------------------------------------
Traceback (most recent call last):
  File "C:\Anaconda3\lib\site-packages\numba\typeinfer.py", line 111, in propagate
    constraint(typeinfer)
  File "C:\Anaconda3\lib\site-packages\numba\typeinfer.py", line 377, in __call__
    index=it, value=vt):
  File "C:\Anaconda3\lib\site-packages\numba\typing\context.py", line 149, in resolve_setitem
    return self.resolve_function_type("setitem", args, kws)
  File "C:\Anaconda3\lib\site-packages\numba\typing\context.py", line 97, in resolve_function_type
    res = defn.apply(args, kws)
  File "C:\Anaconda3\lib\site-packages\numba\typing\templates.py", line 155, in apply
    sig = generic(args, kws)
  File "C:\Anaconda3\lib\site-packages\numba\typing\arraydecl.py", line 158, in generic
    raise TypeError("Cannot modify value of type %s" %(ary,))
TypeError: Cannot modify value of type readonly array(float64, 2d, C)
--%<-----------------------------------------------------------------

File "Base3DSolver4.py", line 61

这是什么意思,我该如何解决?

这是我的代码:

import numpy as np
from scipy.integrate import odeint
from numba import jit

Kvec = np.logspace(-2, 2, 101, 10)
lenK = len(Kvec)
Qvec = 2*np.logspace(-4, 1, 101, 10)
lenQ = len(Qvec)
y0 = [0] * (15)

S0plot = np.zeros((lenK,lenQ))
S1plot = np.zeros((lenK,lenQ))
S2plot = np.zeros((lenK,lenQ))
S3plot = np.zeros((lenK,lenQ))
S4plot = np.zeros((lenK,lenQ))
KS0plot = np.zeros((lenK,lenQ))
KS1plot = np.zeros((lenK,lenQ))
KS2plot = np.zeros((lenK,lenQ))
KS3plot = np.zeros((lenK,lenQ))
PS1plot = np.zeros((lenK,lenQ))
PS2plot = np.zeros((lenK,lenQ))
PS3plot = np.zeros((lenK,lenQ))
PS4plot = np.zeros((lenK,lenQ))
Kplot = np.zeros((lenK,lenQ))
Pplot = np.zeros((lenK,lenQ))

ydot = np.zeros((15,1))

for Qloop in range(lenQ):
    for Kloop in range(lenK):
        K0 = Kvec[Kloop]
        Q = Qvec[Qloop]
        r1 = 2e-5
        r2 = 2e-4
        a = 0.001
        d = 0.001
        k = 0.999
        P0 = 1
        tf = 1e10
        time = np.linspace(0, tf, 1001)
        S00 = Q/r1

        @jit(nopython=True)
        def f(y, t):
            S0 = y[0]
            S1 = y[1]
            S2 = y[2]
            S3 = y[3]
            S4 = y[4]
            KS0 = y[5]
            KS1 = y[6]
            KS2 = y[7]
            KS3 = y[8]
            PS1 = y[9]
            PS2 = y[10]
            PS3 = y[11]
            PS4 = y[12]
            K = y[13]
            P = y[14]

            ydot[0] = Q-r1*S0+d*KS0-a*K*S0+k*(PS1+PS2+PS3+PS4)
            ydot[1] = k*KS0+d*(PS1+KS1)-S1*(r1+a*K+a*P)
            ydot[2] = k*KS1+d*(PS2+KS2)-S2*(r1+a*K+a*P)
            ydot[3] = k*KS2+d*(PS3+KS3)-S3*(r1+a*K+a*P)
            ydot[4] = k*KS3+d*PS4-S4*(r2+a*P)
            ydot[5] = a*K*S0-(d+k+r1)*KS0
            ydot[6] = a*K*S1-(d+k+r1)*KS1
            ydot[7] = a*K*S2-(d+k+r1)*KS2
            ydot[8] = a*K*S3-(d+k+r1)*KS3
            ydot[9] = a*P*S1-(d+k+r1)*PS1
            ydot[10] = a*P*S2-(d+k+r1)*PS2
            ydot[11] = a*P*S3-(d+k+r1)*PS3
            ydot[12] = a*P*S4-(d+k+r2)*PS4
            ydot[13] = (d+k+r1)*(KS0+KS1+KS2+KS3)-a*K*(S0+S1+S2+S3)
            ydot[14] = (d+k+r1)*(PS1+PS2+PS3)+(d+k+r2)*PS4-a*P*(S1+S2+S3+S4)

            return ydot[:,0]

        y0[0] = S00
        y0[13] = K0
        y0[14] = P0

        soln = odeint(f, y0, time, mxstep = 5000)

        S0 = soln[:,0]
        S1 = soln[:,1]
        S2 = soln[:,2]
        S3 = soln[:,3]
        S4 = soln[:,4]
        KS0 = soln[:,5]
        KS1 = soln[:,6]
        KS2 = soln[:,7]
        KS3 = soln[:,8]
        PS1 = soln[:,9]
        PS2 = soln[:,10]
        PS3 = soln[:,11]
        PS4 = soln[:,12]
        K = soln[:,13]
        P = soln[:,14]

        S0plot[Kloop,Qloop] = soln[len(time)-1,0]
        S1plot[Kloop,Qloop] = soln[len(time)-1,1]
        S2plot[Kloop,Qloop] = soln[len(time)-1,2]
        S3plot[Kloop,Qloop] = soln[len(time)-1,3]
        S4plot[Kloop,Qloop] = soln[len(time)-1,4]
        KS0plot[Kloop,Qloop] = soln[len(time)-1,5]
        KS1plot[Kloop,Qloop] = soln[len(time)-1,6]
        KS2plot[Kloop,Qloop] = soln[len(time)-1,7]
        KS3plot[Kloop,Qloop] = soln[len(time)-1,8]
        PS1plot[Kloop,Qloop] = soln[len(time)-1,9]
        PS2plot[Kloop,Qloop] = soln[len(time)-1,10]
        PS3plot[Kloop,Qloop] = soln[len(time)-1,11]
        PS4plot[Kloop,Qloop] = soln[len(time)-1,12]
        Kplot[Kloop,Qloop] = soln[len(time)-1,13]
        Pplot[Kloop,Qloop] = soln[len(time)-1,14]
        Stot = S0plot + S1plot + S2plot + S3plot + S4plot + KS1plot + \
               KS2plot + KS3plot + PS1plot + PS2plot + PS3plot + PS4plot

        Smod = S4plot

print(Stot)
print(Smod)

感谢您的帮助。

您正在尝试从已在 nopython 模式下 JIT 的 Numba 函数就地修改全局变量(即 ydot)。 This is not supported。要么改变一些东西,使 ydot 成为函数的参数,要么不要在 nopython 模式下编译。