Scipy 优化 - 缺少 2 个必需的位置参数

Scipy Optimize - missing 2 required positional arguments

我在尝试最小化非常简单的代码时遇到了这个问题。

from scipy.optimize import fsolve
from scipy.optimize import minimize

design_var_0 = (0.1, 500, 0.5)

def height_equation(h, D, n, m_batt):
    return return h - D*n*m_batt

def height(design_var):
    return fsolve(height_equation, x0=10000, args=design_var)[0]

def respro_height(design_var):
    return 1/height(design_var)

minimize(respro_height, design_var_0)
  File "solver.py", line 36, in height
    return fsolve(height_equation, x0=10000, args=design_var)[0]
  File "C:\Users\kj\Anaconda3\lib\site-packages\scipy\optimize\minpack.py", line 160, in fsolve
    res = _root_hybr(func, x0, args, jac=fprime, **options)
  File "C:\Users\kj\Anaconda3\lib\site-packages\scipy\optimize\minpack.py", line 226, in _root_hybr
    shape, dtype = _check_func('fsolve', 'func', func, x0, args, n, (n,))
  File "C:\Users\kj\Anaconda3\lib\site-packages\scipy\optimize\minpack.py", line 24, in _check_func
    res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
TypeError: height_equation() missing 2 required positional arguments: 'n' and 'm_batt'

所以似乎 内部优化循环 和 fsolve 给了我这个错误。 但是,如果我只是用元组调用 respro_height,我会按预期得到浮点数答案。

>>> respro_height(design_var_0)
0.0012771435682421253

有人可以解释一下,为什么我会收到此错误以及如何修复它?

请注意,函数height_equation指的是解析函数,其形式为lambert-W,不能直接求解高度。这就是为什么我使用 fsolve 来获取高度

当通过 minimize 调用时,height() 中的 deisgn_var 参数作为数组而不是元组传递。我对 minimize 的内部工作原理不够熟悉,无法理解原因。但是下面的小修改,将参数显式转换为元组,应该可以修复它

def height(design_var):
    return fsolve(height_equation, x0=10000, args=tuple(design_var))[0]

输出

     fun: 4.363615312479626e-05
 hess_inv: array([[6.05302726e+04, 6.37807548e+01, 5.21903776e+04],
       [6.37807548e+01, 1.06720689e+00, 5.49938958e+01],
       [5.21903776e+04, 5.49938958e+01, 4.50012847e+04]])
      jac: array([-6.05549531e-06, -8.72714736e-08, -6.86066005e-06])
  message: 'Optimization terminated successfully.'
     nfev: 64
      nit: 14
     njev: 16
   status: 0
  success: True
        x: array([  7.20604199, 500.00712565,   6.36034335])