使用 scipy 拟合微分方程,得到 "object too deep for desired array"

Fit a differential equation using scipy, getting "object too deep for desired array"

我正在尝试将曲线拟合到微分方程。为了简单起见,我只是在这里做逻辑方程。我写了下面的代码,但下面显示了一个错误。我不太确定我做错了什么。

import numpy as np
import pandas as pd
import scipy.optimize as optim
from scipy.integrate import odeint

df_yeast = pd.DataFrame({'cd': [9.6, 18.3, 29., 47.2, 71.1, 119.1, 174.6, 257.3, 350.7, 441., 513.3, 559.7, 594.8, 629.4, 640.8, 651.1, 655.9, 659.6], 'td': np.arange(18)})

N0 = 1
parsic = [5, 2]

def logistic_de(t, N, r, K):
    return r*N*(1 - N/K)

def logistic_solution(t, r, K):
    return odeint(logistic_de, N0, t, (r, K))

params, _ = optim.curve_fit(logistic_solution, df_yeast['td'], df_yeast['cd'], p0=parsic);
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
ValueError: object too deep for desired array

---------------------------------------------------------------------------
error                                     Traceback (most recent call last)
<ipython-input-94-2a5a467cfa43> in <module>
----> 1 params, _ = optim.curve_fit(logistic_solution, df_yeast['td'], df_yeast['cd'], p0=parsic);

~/SageMath/local/lib/python3.9/site-packages/scipy/optimize/minpack.py in curve_fit(f, xdata, ydata, p0, sigma, absolute_sigma, check_finite, bounds, method, jac, **kwargs)
    782         # Remove full_output from kwargs, otherwise we're passing it in twice.
    783         return_full = kwargs.pop('full_output', False)
--> 784         res = leastsq(func, p0, Dfun=jac, full_output=1, **kwargs)
    785         popt, pcov, infodict, errmsg, ier = res
    786         ysize = len(infodict['fvec'])

~/SageMath/local/lib/python3.9/site-packages/scipy/optimize/minpack.py in leastsq(func, x0, args, Dfun, full_output, col_deriv, ftol, xtol, gtol, maxfev, epsfcn, factor, diag)
    420         if maxfev == 0:
    421             maxfev = 200*(n + 1)
--> 422         retval = _minpack._lmdif(func, x0, args, full_output, ftol, xtol,
    423                                  gtol, maxfev, epsfcn, factor, diag)
    424     else:

error: Result from function call is not a proper array of floats.

@hpaulj 指出了 logistic_solution 中 return 值的形状问题,并表明修复消除了您报告的错误。

但是,代码中还有另一个问题。该问题不会产生错误,但会导致您的测试问题(逻辑微分方程)的解不正确。默认情况下,odeint 期望计算微分方程的函数的 t 参数是 second 参数。要么更改 logistic_de 的前两个参数的顺序,要么将参数 tfirst=True 添加到 odeint 的调用中。第二个选项更好一点,因为如果您决定尝试使用 logistic_descipy.integrate.solve_ivp 而不是 odeint.

,它将允许您使用 logistic_descipy.integrate.solve_ivp

logistic_solution 的样本 运行 产生 (18,1) 结果:

In [268]: logistic_solution(df_yeast['td'], *parsic)
Out[268]: 
array([[ 1.00000000e+00],
       [ 2.66666671e+00],
       [ 4.33333337e+00],
       [ 1.00000004e+00],
       [-1.23333333e+01],
       [-4.06666666e+01],
       [-8.90000000e+01],
       [-1.62333333e+02],
       [-2.65666667e+02],
       [-4.04000000e+02],
       [-5.82333333e+02],
       [-8.05666667e+02],
       [-1.07900000e+03],
       [-1.40733333e+03],
       [-1.79566667e+03],
       [-2.24900000e+03],
       [-2.77233333e+03],
       [-3.37066667e+03]])
In [269]: _.shape
Out[269]: (18, 1)

y 值是

In [281]: df_yeast['cd'].values.shape
Out[281]: (18,)

定义一个returns一维数组的替代函数:

In [282]: def foo(t,r,K):
     ...:     return logistic_solution(t,r,K).ravel()

这个有效:

In [283]: params, _ = optim.curve_fit(foo, df_yeast['td'], df_yeast['cd'], p0=parsic)
In [284]: params
Out[284]: array([16.65599815, 15.52779946])

测试 params:

In [287]: logistic_solution(df_yeast['td'], *params)
Out[287]: 
array([[  1.        ],
       [  8.97044688],
       [ 31.45157847],
       [ 66.2980814 ],
       [111.36464226],
       [164.50594767],
       [223.5766842 ],
       [286.43153847],
       [350.92519706],
       [414.91234658],
       [476.24767362],
       [532.78586477],
       [582.38160664],
       [622.88958582],
       [652.1644889 ],
       [668.0610025 ],
       [668.43381319],
       [651.13760758]])
In [288]: df_yeast['cd'].values
Out[288]: 
array([  9.6,  18.3,  29. ,  47.2,  71.1, 119.1, 174.6, 257.3, 350.7,
       441. , 513.3, 559.7, 594.8, 629.4, 640.8, 651.1, 655.9, 659.6])

too deep 在此上下文中表示二维数组,当它应该是一维时,以便与 ydata

进行比较
ydata : array_like
    The dependent data, a length M array - nominally ``f(xdata, ...)``.