从另一个文件导入时,我在为 scipy 的 leastsq 优化实现残差函数时遇到问题
I have problem implementing residuals function for leastsq optimization of scipy when importing it from another file
我写了一段函数相互调用的代码。工作代码如下:
import numpy as np
from scipy.optimize import leastsq
import RF
func = RF.roots
# residuals = RF.residuals
def residuals(params, x, y):
return y - func(params, x)
def estimation(x, y):
p_guess = [1, 2, 0.5, 0]
params, cov, infodict, mesg, ier = leastsq(residuals, p_guess, args=(x, y), full_output=True)
return params
x = np.array([2.78e-03, 3.09e-03, 3.25e-03, 3.38e-03, 3.74e-03, 4.42e-03, 4.45e-03, 4.75e-03, 8.05e-03, 1.03e-02, 1.30e-02])
y = np.array([2.16e+02, 2.50e+02, 3.60e+02, 4.48e+02, 5.60e+02, 8.64e+02, 9.00e+02, 1.00e+03, 2.00e+03, 3.00e+03, 4.00e+03])
FIT_params = estimation(x, y)
print(FIT_params)
其中 RF 文件是:
def roots(params, x):
a, b, c, d = params
y = a * (b * x) ** c + d
return y
def residuals(params, x, y):
return y - func(params, x)
我想从主代码中删除 residuals 函数并通过从 RF 文件中调用来使用它,即通过激活代码第 residuals = RF.residuals
行。这样做会出现错误NameError: name 'func' is not defined
。我把 func 参数放在 RF 的 residuals 函数中作为 def residuals(func, params, x, y):
,它将面临错误 TypeError: residuals() missing 1 required positional argument: 'y'
;似乎错误与此示例中残差函数的 forth 参数 有关,因为如果 func 参数放在 y 参数之后。我找不到问题的根源,但我想它一定与limitation of arguments in functions有关。如果有人能指导我理解错误及其解决方案,我将不胜感激。
是否可以将 residual 函数从主代码带到 RF 文件中?怎么样?
问题是您的文件 RF.py
中没有全局变量 func
,因此无法找到它。一个简单的解决方案是向 residuals
函数添加一个附加参数:
# RF.py
def roots(params, x):
a, b, c, d = params
y = a * (b * x) ** c + d
return y
def residuals(params, func, x, y):
return y - func(params, x)
然后,您可以像这样在其他文件中使用它:
import numpy as np
from scipy.optimize import leastsq
from RF import residuals, roots as func
def estimation(func, x, y):
p_guess = [1, 2, 0.5, 0]
params, cov, infodict, mesg, ier = leastsq(residuals, p_guess, args=(func, x, y), full_output=True)
return params
x = np.array([2.78e-03, 3.09e-03, 3.25e-03, 3.38e-03, 3.74e-03, 4.42e-03, 4.45e-03, 4.75e-03, 8.05e-03, 1.03e-02, 1.30e-02])
y = np.array([2.16e+02, 2.50e+02, 3.60e+02, 4.48e+02, 5.60e+02, 8.64e+02, 9.00e+02, 1.00e+03, 2.00e+03, 3.00e+03, 4.00e+03])
FIT_params = estimation(func, x, y)
print(FIT_params)
我写了一段函数相互调用的代码。工作代码如下:
import numpy as np
from scipy.optimize import leastsq
import RF
func = RF.roots
# residuals = RF.residuals
def residuals(params, x, y):
return y - func(params, x)
def estimation(x, y):
p_guess = [1, 2, 0.5, 0]
params, cov, infodict, mesg, ier = leastsq(residuals, p_guess, args=(x, y), full_output=True)
return params
x = np.array([2.78e-03, 3.09e-03, 3.25e-03, 3.38e-03, 3.74e-03, 4.42e-03, 4.45e-03, 4.75e-03, 8.05e-03, 1.03e-02, 1.30e-02])
y = np.array([2.16e+02, 2.50e+02, 3.60e+02, 4.48e+02, 5.60e+02, 8.64e+02, 9.00e+02, 1.00e+03, 2.00e+03, 3.00e+03, 4.00e+03])
FIT_params = estimation(x, y)
print(FIT_params)
其中 RF 文件是:
def roots(params, x):
a, b, c, d = params
y = a * (b * x) ** c + d
return y
def residuals(params, x, y):
return y - func(params, x)
我想从主代码中删除 residuals 函数并通过从 RF 文件中调用来使用它,即通过激活代码第 residuals = RF.residuals
行。这样做会出现错误NameError: name 'func' is not defined
。我把 func 参数放在 RF 的 residuals 函数中作为 def residuals(func, params, x, y):
,它将面临错误 TypeError: residuals() missing 1 required positional argument: 'y'
;似乎错误与此示例中残差函数的 forth 参数 有关,因为如果 func 参数放在 y 参数之后。我找不到问题的根源,但我想它一定与limitation of arguments in functions有关。如果有人能指导我理解错误及其解决方案,我将不胜感激。
是否可以将 residual 函数从主代码带到 RF 文件中?怎么样?
问题是您的文件 RF.py
中没有全局变量 func
,因此无法找到它。一个简单的解决方案是向 residuals
函数添加一个附加参数:
# RF.py
def roots(params, x):
a, b, c, d = params
y = a * (b * x) ** c + d
return y
def residuals(params, func, x, y):
return y - func(params, x)
然后,您可以像这样在其他文件中使用它:
import numpy as np
from scipy.optimize import leastsq
from RF import residuals, roots as func
def estimation(func, x, y):
p_guess = [1, 2, 0.5, 0]
params, cov, infodict, mesg, ier = leastsq(residuals, p_guess, args=(func, x, y), full_output=True)
return params
x = np.array([2.78e-03, 3.09e-03, 3.25e-03, 3.38e-03, 3.74e-03, 4.42e-03, 4.45e-03, 4.75e-03, 8.05e-03, 1.03e-02, 1.30e-02])
y = np.array([2.16e+02, 2.50e+02, 3.60e+02, 4.48e+02, 5.60e+02, 8.64e+02, 9.00e+02, 1.00e+03, 2.00e+03, 3.00e+03, 4.00e+03])
FIT_params = estimation(func, x, y)
print(FIT_params)