Python:为 scipy.optimize.curve_fit 使用可按名称选择的通用拟合函数

Python: Use a generic, name-selectable model function with scipy.optimize.curve_fit

我想用 Python 对一些数据进行曲线拟合。我的程序如下所示:

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import curve_fit
from sklearn.metrics import mean_squared_error

def lin(x, a, b, c):
    """Straight-line model ``a*x + b``.

    ``c`` is accepted but ignored, so that every model shares the
    ``(x, a, b, c)`` signature that curve_fit introspects.
    """
    slope, intercept = a, b
    return slope * x + intercept

def exp(x, a, b, c):
    """Exponential model ``a * e**(b*x) + c``, element-wise for arrays."""
    growth = np.exp(b * x)
    return a * growth + c

def ln(x, a, b, c):
    """Logarithmic model ``a * ln(x + b) + c``.

    Where ``x + b <= 0``, np.log yields -inf/nan (with a warning) rather
    than raising.
    """
    shifted = b + x
    return a * np.log(shifted) + c

# Ten x positions but only nine observations: x = 10 has no measured y.
x_dummy = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y_dummy = np.array([9.2, 9.9, 10.0, 11.2, 10.2, 12.6, 10.0, 11.6, 12.2])

# Training set: the first eight (x, y) pairs. The ninth observation (x = 9)
# is left out of the fit.
x_train = x_dummy[:-2]
y_train = y_dummy[:-1]
popt, _ = curve_fit(lin, x_train, y_train)

# Evaluate the fitted line at every x, including the unobserved x = 10.
y_approx = lin(x_dummy, *popt)

print(y_approx[-1])

print(popt)
# In-sample error: compare the model only on the points it was fitted to.
print(mean_squared_error(y_train, y_approx[:len(y_train)]))

# Observed data (only the x values that have a y) vs. the fitted curve.
plt.plot(x_dummy[:len(y_dummy)], y_dummy, color='blue')
plt.plot(x_dummy, y_approx, color='green')
plt.show()

我现在的目标是编写一个名为 fn 的通用函数,它接受一个参数(例如表示模型名称的字符串),使得调用

popt, _ = curve_fit(fn('lin' or 'exp' or 'ln'), x_dummy[:-2], y_dummy[:-1])

与下面的调用含义相同(即 fn('lin') 等价于直接传入 lin,'exp'、'ln' 同理;这里的 or 只是表示"三者之一"):
popt, _ = curve_fit(lin or exp or ln, x_dummy[:-2], y_dummy[:-1])

背景:我想定义一个数组 ['lin', 'exp', 'ln'],遍历这三种可能的曲线拟合,分别计算各拟合结果的均方误差,并找出误差最小的那一种。

我找到了一种方法,但也许还有更简单的做法:

from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

class FunctionCollector():
    """Three-parameter model functions that can be selected by name.

    All models share the ``(x, a, b, c)`` signature, so ``fn`` -- which
    dispatches to the currently selected model -- can be passed straight to
    ``scipy.optimize.curve_fit`` whichever model is active.
    """

    # Names accepted by setFunc/getFunc; each is also the name of the method
    # implementing that model, which is what makes the getattr lookup safe.
    MODEL_NAMES = ('lin', 'exp', 'ln')

    def __init__(self, name='lin'):
        """Create a collector with model ``name`` (default ``'lin'``) selected.

        Raises:
            ValueError: if ``name`` is not one of ``MODEL_NAMES``.
        """
        self.setFunc(name)

    def setFunc(self, name):
        """Select the model that ``fn`` dispatches to.

        Raises:
            ValueError: if ``name`` is not one of ``MODEL_NAMES``; the current
                selection is left unchanged in that case.
        """
        # Validate before assigning: an unknown name must fail loudly here
        # rather than surface later as a meaningless fit.
        self._lookup(name)
        self.name = name

    def getFunc(self, name=None):
        """Return the model callable for ``name`` (default: the selected one).

        The returned bound method can be handed to ``curve_fit`` directly,
        e.g. ``curve_fit(collector.getFunc('exp'), x, y)``.

        Raises:
            ValueError: if the name is not one of ``MODEL_NAMES``.
        """
        return self._lookup(self.name if name is None else name)

    def _lookup(self, name):
        """Return the bound model method called ``name`` or raise ValueError."""
        if name not in self.MODEL_NAMES:
            raise ValueError(
                f"unknown model {name!r}; expected one of {self.MODEL_NAMES}")
        return getattr(self, name)

    def lin(self, x, a, b, c):
        """Straight line ``a*x + b``.

        ``c`` is unused (kept for the shared signature). Because it has no
        effect, curve_fit cannot determine it and will report an infinite
        covariance for it.
        """
        return a*x+b

    def exp(self, x, a, b, c):
        """Exponential model ``a * e**(b*x) + c``."""
        return a*np.exp(b*x)+c

    def ln(self, x, a, b, c):
        """Logarithmic model ``a * ln(x + b) + c`` (nan/-inf where x + b <= 0)."""
        return a*np.log(b+x)+c

    def fn(self, x, a, b, c):
        """Evaluate the currently selected model at ``x``.

        Raises:
            ValueError: if ``self.name`` was set to an unknown model name.
        """
        return self._lookup(self.name)(x, a, b, c)



def l(x, a, b, c):
    """Linear model ``a*x + b``; ``c`` is unused (kept for a uniform signature)."""
    return b + a * x
# Required by np.array below and by the models' np.exp / np.log calls.
import numpy as np

# Ten x positions but only nine observations: x = 10 has no measured y.
x_dummy = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
y_dummy = np.array([9.2, 9.9, 10.0, 11.2, 10.2, 12.6, 10.0, 11.6, 12.2])

# Optional: uncomment to add Gaussian noise (std 5) to the observations.
#noise = 5*np.random.normal(size=y_dummy.size)
#y_dummy = y_dummy + noise

f = FunctionCollector()

# Training set: the first eight (x, y) pairs; the ninth observation (x = 9)
# is left out of the fit.
x_train = x_dummy[:-2]
y_train = y_dummy[:-1]
popt, _ = curve_fit(f.fn, x_train, y_train)

# Evaluate the fitted model at every x, including the unobserved x = 10.
y_approx = f.fn(x_dummy, *popt)

print(y_approx[-1])

print(popt)
# In-sample error: compare the model only on the points it was fitted to.
print(mean_squared_error(y_train, y_approx[:len(y_train)]))

# Observed data (only the x values that have a y) vs. the fitted curve.
plt.plot(x_dummy[:len(y_dummy)], y_dummy, color='blue')
plt.plot(x_dummy, y_approx, color='green')
plt.show()