scipy.optimize.curve_fit 线性拟合

Linear fit with scipy.optimize.curve_fit

我上周才开始编程,所以请温柔点;)

我尝试做的是用 curve_fit 进行线性拟合以确定对斜率的两个贡献。我试过了:

import os
from os import listdir
from os.path import isfile, join
from scipy.optimize import curve_fit
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.axes as ax 
from scipy import asarray as ar,exp

freq_lw_tuple = [('10.0', 61.32542701946727), ('10.5', 39.367147015318501), ('11.0', 58.432147581817077), ('11.5', 68.819336676144317), ('12.0', 71.372078906193025), ('12.5', 73.113662907539336), ('13.0', 75.855316075062603), ('13.5', 76.798724771266322), ('14.0', 79.065657225891329), ('14.5', 81.637345805693897), ('15.0', 82.407248034320034)]
def func_lw(x_lw,alpha,g):
    return (alpha/g)*x_lw

#define constants
m_e = 9.109383*(10**(-31)) #mass of electron in [kg]
q_e = 1.602176*(10**(-19)) #value of electron charge
pi = 3.14159               #pi
#unzip the list of tuples
unzipped = list(zip(*freq_lw_tuple))

xval_list = []
yval_list = []
for k in range(0,len(unzipped[0])):
    x_value = (8/np.sqrt(3))*((2*pi*m_e*float(unzipped[0][k])*10**9)/q_e)   #calculate x values
    xval_list.append(x_value)
    y_value = unzipped[1][k]*10**(-4)*4*pi*10**(-7)    #transform unit of y values
    yval_list.append(y_value)


start_params2 = [0.01,2]
fitted_params, pcov = curve_fit(func_lw, xval_list, yval_list, start_params2)

它实际上给出了一些结果,但是当我想

print(func_lw(xval_list,*fitted_params))

我只是得到一个空列表,这可能就是我不能

的原因
plt.plot(xval_list, func_lw(xval_list, *fitted_params))

(这会产生如下错误:x 和 y 必须具有相同的第一维)

[编辑:为 freq_lw_tuple 和导入添加了一些数据]

Python 不能乘以列表和标量。 这就是您导入 from scipy import asarray as ar 的原因。您也可以使用 numpy。 所以当你调用 func_lw 时,你应该给它一个数组,而不是一个列表。

func_lw(ar(xval_list), fitted_params[0], fitted_params[1] )
plt.plot(ar(xval_list), func_lw(ar(xval_list), *fitted_params))
plt.scatter(ar(xval_list), ar(yval_list))

关于优化函数,如果你在公式中使用(alpha/g)*x_lw,你实际上只有一个参数(斜率)。我会改用

def func_lw(x_lw, slope, offset):
    return slope*x_lw + offset

编辑: 顺便说一句,你应该使用 "list comprehension",并做类似

的事情
xval_list = [(8/np.sqrt(3))*((2*pi*m_e*float(x)*10**9)/q_e) for x, y in freq_lw_tuple]
yval_list = [y*10**(-4)*4*pi*10**(-7) for x, y in freq_lw_tuple]