从另一个文件导入时,我在为 scipy 的 leastsq 优化实现残差函数时遇到问题

I have problem implementing residuals function for leastsq optimization of scipy when importing it from another file

我写了一段函数相互调用的代码。工作代码如下:

import numpy as np
from scipy.optimize import leastsq
import RF

func = RF.roots
# residuals = RF.residuals

def residuals(params, x, y):
    return y - func(params, x)

def estimation(x, y):
    p_guess = [1, 2, 0.5, 0]
    params, cov, infodict, mesg, ier = leastsq(residuals, p_guess, args=(x, y), full_output=True)
    return params

x = np.array([2.78e-03, 3.09e-03, 3.25e-03, 3.38e-03, 3.74e-03, 4.42e-03, 4.45e-03, 4.75e-03, 8.05e-03, 1.03e-02, 1.30e-02])
y = np.array([2.16e+02, 2.50e+02, 3.60e+02, 4.48e+02, 5.60e+02, 8.64e+02, 9.00e+02, 1.00e+03, 2.00e+03, 3.00e+03, 4.00e+03])

FIT_params = estimation(x, y)
print(FIT_params)

其中 RF 文件是:

def roots(params, x):
    a, b, c, d = params
    y = a * (b * x) ** c + d
    return y

def residuals(params, x, y):
    return y - func(params, x)

我想从主代码中删除 residuals 函数并通过从 RF 文件中调用来使用它,即通过激活代码第 residuals = RF.residuals 行。这样做会出现错误NameError: name 'func' is not defined。我把 func 参数放在 RF 的 residuals 函数中作为 def residuals(func, params, x, y): ,它将面临错误 TypeError: residuals() missing 1 required positional argument: 'y';似乎错误与此示例中残差函数的 forth 参数 有关,因为如果 func 参数放在 y 参数之后。我找不到问题的根源,但我想它一定与limitation of arguments in functions有关。如果有人能指导我理解错误及其解决方案,我将不胜感激。
是否可以将 residual 函数从主代码带到 RF 文件中?怎么样?

问题是您的文件 RF.py 中没有全局变量 func,因此无法找到它。一个简单的解决方案是向 residuals 函数添加一个附加参数:

# RF.py
def roots(params, x):
    a, b, c, d = params
    y = a * (b * x) ** c + d
    return y

def residuals(params, func, x, y):
    return y - func(params, x)

然后,您可以像这样在其他文件中使用它:

import numpy as np
from scipy.optimize import leastsq
from RF import residuals, roots as func

def estimation(func, x, y):
    p_guess = [1, 2, 0.5, 0]
    params, cov, infodict, mesg, ier = leastsq(residuals, p_guess, args=(func, x, y), full_output=True)
    return params

x = np.array([2.78e-03, 3.09e-03, 3.25e-03, 3.38e-03, 3.74e-03, 4.42e-03, 4.45e-03, 4.75e-03, 8.05e-03, 1.03e-02, 1.30e-02])
y = np.array([2.16e+02, 2.50e+02, 3.60e+02, 4.48e+02, 5.60e+02, 8.64e+02, 9.00e+02, 1.00e+03, 2.00e+03, 3.00e+03, 4.00e+03])

FIT_params = estimation(func, x, y)
print(FIT_params)