当从 ODE 求解器调用非线性方程求解器时,fsolve mismatch shape error
fsolve mismatch shape error when nonlinear equations solver called from ODE solver
我的函数 "par_impl(y)" 中有一个由两个非线性方程组成的系统,我可以使用 scipy.optimize.root 独立求解。这里 "y" 是一个参数。
但我希望从 ODE 求解器 odeint 调用该系统,以便针对 "y" 的不同值求解它,并与简单的 ODE 耦合。
这给了我 fsolve 不匹配形状错误。
import numpy as np
from scipy.optimize import root import
matplotlib.pyplot as plt
from scipy.integrate import odeint
def par_impl(y):
def functionh(x):
return [y + (x[0]**2) - x[1] -23, -4 - y*x[0] + (x[1]**2)]
sol = root(functionh, [1, 1])
return sol.x
def dy_dt(y, t):
dydt = (y**0.5) + par_impl(y)[0]
return dydt
ls = np.linspace(0, 2, 50) y_0 = 2
Ps = odeint(dy_dt, y_0, ls)
y = Ps[:,0]
plt.plot(ls, y, "+", label="X") plt.legend(); plt.figure()
我得到的错误是:
File
"C:\Users\matteo\AppData\Local\Continuum\anaconda3\lib\site-packages\scipy\optimize\minpack.py",
line 41, in _check_func
raise TypeError(msg)
TypeError: fsolve: there is a mismatch between the input and output
shape of the 'func' argument 'functionh'.Shape should be (2,) but it
is (2, 1).
问题是 y
是您代码中 len
= 1
的列表。要访问其元素,您需要在函数中使用 y[0]
。下面是带有输出图的代码的工作版本(未显示整个代码)。
from scipy.optimize import root
from scipy.integrate import odeint
# Other plotting and numpy imports
def par_impl(y):
def functionh(x):
return [y[0] + (x[0]**2) - x[1] -23, -4 - y[0]*x[0] + (x[1]**2)] # y --> y[0]
sol = root(functionh, ([1, 1]))
return sol.x
# dy_dt function here
# Rest of your code unchanged
plt.plot(ls, y, "+", label="X")
plt.legend()
输出
我的函数 "par_impl(y)" 中有一个由两个非线性方程组成的系统,我可以使用 scipy.optimize.root 独立求解。这里 "y" 是一个参数。 但我希望从 ODE 求解器 odeint 调用该系统,以便针对 "y" 的不同值求解它,并与简单的 ODE 耦合。 这给了我 fsolve 不匹配形状错误。
import numpy as np
from scipy.optimize import root import
matplotlib.pyplot as plt
from scipy.integrate import odeint
def par_impl(y):
def functionh(x): return [y + (x[0]**2) - x[1] -23, -4 - y*x[0] + (x[1]**2)] sol = root(functionh, [1, 1]) return sol.x
def dy_dt(y, t):
dydt = (y**0.5) + par_impl(y)[0] return dydt
ls = np.linspace(0, 2, 50) y_0 = 2
Ps = odeint(dy_dt, y_0, ls)
y = Ps[:,0]
plt.plot(ls, y, "+", label="X") plt.legend(); plt.figure()
我得到的错误是:
File "C:\Users\matteo\AppData\Local\Continuum\anaconda3\lib\site-packages\scipy\optimize\minpack.py", line 41, in _check_func raise TypeError(msg)
TypeError: fsolve: there is a mismatch between the input and output shape of the 'func' argument 'functionh'.Shape should be (2,) but it is (2, 1).
问题是 y
是您代码中 len
= 1
的列表。要访问其元素,您需要在函数中使用 y[0]
。下面是带有输出图的代码的工作版本(未显示整个代码)。
from scipy.optimize import root
from scipy.integrate import odeint
# Other plotting and numpy imports
def par_impl(y):
def functionh(x):
return [y[0] + (x[0]**2) - x[1] -23, -4 - y[0]*x[0] + (x[1]**2)] # y --> y[0]
sol = root(functionh, ([1, 1]))
return sol.x
# dy_dt function here
# Rest of your code unchanged
plt.plot(ls, y, "+", label="X")
plt.legend()
输出