Setting initial values to parameters when fitting a curve with lmfit

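I'm using a composite model (two Gaussian models) to fit a curve with lmfit, and the fit result seems to depend entirely on the initial values I give. What is the best way to set initial parameter values? I know of three different ways to do this with lmfit: Parameters.add(), Parameter.set() and Model.set_param_hint(), but I don't fully understand the differences between them. The documentation suggests that set_param_hint is a good way to do this, but I would like to know how it differs from the other methods.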

Here is an example of my code, using the different methods (add and set) to illustrate my confusion:

import numpy as np
from lmfit import Model, Parameters
import matplotlib.pyplot as plt

x_val = [4460.1758349164, 4460.375833832813, 4460.575832749225, 4460.775831665638, 4460.975830582051, 4461.175829498463, 4461.375828414875, 4461.575827331288, 4461.775826247701, 4461.975825164113, 4462.175824080526, 4462.375822996938, 4462.57582191335, 4462.775820829764, 4462.975819746176, 4463.1758186625875, 4463.375817579001, 4463.575816495413, 4463.775815411826, 4463.975814328239, 4464.175813244651, 4464.375812161064, 4464.575811077476, 4464.775809993889, 4464.9758089103025, 4465.175807826714, 4465.375806743126, 4465.57580565954, 4465.775804575952, 4465.975803492364, 4466.175802408777, 4466.375801325189, 4466.575800241602, 4466.775799158015, 4466.975798074427, 4467.175796990839, 4467.375795907252, 4467.575794823664, 4467.7757937400775, 4467.97579265649, 4468.175791572902, 4468.3757904893155, 4468.575789405728, 4468.77578832214, 4468.975787238553, 4469.175786154965, 4469.375785071377, 4469.5757839877915, 4469.775782904203, 4469.975781820615, 4470.175780737029, 4470.37577965344, 4470.575778569853, 4470.775777486266, 4470.975776402678, 4471.1757753190905, 4471.375774235503, 4471.575773151916, 4471.7757720683285, 4471.975770984741, 4472.175769901153, 4472.375768817566, 4472.575767733979, 4472.775766650391, 4472.9757655668045, 4473.175764483216, 4473.375763399628, 4473.575762316042, 4473.775761232454, 4473.975760148866, 4474.175759065279, 4474.375757981692, 4474.575756898104, 4474.775755814518, 4474.975754730929, 4475.1757536473415, 4475.375752563754, 4475.575751480167, 4475.77575039658, 4475.975749312992, 4476.175748229404, 4476.3757471458175, 4476.57574606223, 4476.775744978642, 4476.975743895055, 4477.175742811467, 4477.375741727879, 4477.575740644294, 4477.775739560705, 4477.9757384771165, 4478.17573739353, 4478.375736309943, 4478.575735226355, 4478.775734142768, 4478.97573305918, 4479.175731975593, 4479.375730892006, 4479.575729808418, 4479.7757287248305, 4479.975727641243, 4480.175726557655, 4480.375725474069, 4480.575724390481, 4480.775723306892, 
4480.975722223307, 4481.175721139718, 4481.375720056131, 4481.575718972544, 4481.775717888956, 4481.975716805368, 4482.175715721783, 4482.375714638194, 4482.5757135546055, 4482.77571247102, 4482.975711387431, 4483.175710303844, 4483.375709220257, 4483.575708136668, 4483.7757070530815, 4483.975705969494, 4484.175704885907, 4484.3757038023205, 4484.575702718732, 4484.775701635144, 4484.975700551557]
y_val = [1.0438815599549134, 0.9861559707471772, 1.0056426645990315, 1.0016074526378649, 1.0452997007422666, 0.992212205281684, 1.0365215397316232, 1.0218869075138342, 1.0055580715537948, 1.0156218890501965, 1.028214904229718, 0.9935787796492273, 1.02364139796149, 1.0179358129807576, 1.035762676388034, 1.049932333954558, 1.0402954847373662, 1.0169711103176595, 1.0340240575460198, 1.049747768424791, 1.0175400582158902, 1.0103838602023636, 1.0680006544649665, 0.975519363154844, 1.0202812597671398, 0.9695222898779196, 1.0052738395140506, 1.0053855702044892, 0.9935941046265898, 0.986614047747308, 0.9986655992818708, 0.9999356062287996, 1.0240484329659438, 0.9819990493350282, 1.000327008341581, 0.9717165926477822, 0.9879546941197598, 0.9842935196212136, 1.0222486380060392, 0.975275958755044, 1.0498618707695202, 1.025608170066069, 0.9909686718827492, 0.975939608797198, 0.9467728492315236, 0.9480619167488604, 0.9600094590732424, 0.9636132733406744, 0.9944894010124092, 0.9426361826831244, 0.9782473212039978, 0.9378327202091502, 0.9488207621805942, 0.9669396283466724, 0.9432847772067492, 0.9015761099378126, 0.9135968691755808, 0.8939703886252973, 0.8573607070423116, 0.868161237954455, 0.8849968824099054, 0.885539805042943, 0.844515618445441, 0.8842305221856582, 0.8877296440122721, 0.8821343557372545, 0.9075013206055316, 0.8660876250948828, 0.9127519356948968, 0.8952841088988195, 0.9602437689940024, 1.0375435216069926, 1.1326450855548746, 1.2528373417955827, 1.359064567678794, 1.7397790583320276, 2.2955575263013603, 2.6313330486608075, 2.7696361971739485, 2.2290943507722045, 1.5299780348545342, 1.1265789292075985, 0.9761209131908825, 0.9552781525369406, 0.9872235913023412, 0.9554892446527146, 0.9693081918466234, 0.9565660500653812, 0.9460822542921022, 0.9266113291876116, 0.9704238862428936, 0.8915634335508363, 0.9158114443978326, 0.9466235269126626, 0.9451751549645016, 0.946265616542422, 0.9367300273679332, 0.971583009744108, 0.9435038781374095, 0.9892258250694016, 
0.9754689843339546, 0.9578096187257352, 0.9649331079033204, 0.9709409505255512, 0.9818618967434926, 0.9732673864230984, 0.9970556441582832, 0.9810274934718626, 0.939447766493294, 1.0112673683488067, 1.0191757378152404, 0.9835438808599056, 0.985619193341479, 0.9862022169399436, 1.0458502824889473, 0.9594215029321304, 0.9971740675615232, 0.9974173269531228, 0.9955615254192632, 1.03531504592408, 1.0077373609120324, 1.0009705059358802, 1.0206465226122023, 0.9591259867321692, 1.0148009048782651]
err = [0.0356014742203398, 0.028023844164620268, 0.02706632229192564, 0.026921086598994004, 0.03404330335778127, 0.03225575706454388, 0.032951103033851084, 0.02550825680398673, 0.029497494361785826, 0.025198158411558855, 0.03492983187381606, 0.03163083328614311, 0.027704308525917317, 0.03494848818894923, 0.030014846715378605, 0.035741193441217865, 0.03078218636873445, 0.023901828310539986, 0.03628052312062977, 0.035025392619838884, 0.03976648591093106, 0.02780543058799098, 0.040944290884658216, 0.034099200916427784, 0.03205306075906642, 0.03326464028563125, 0.02337626476347709, 0.026083179277841928, 0.028218666012639764, 0.04596683621166614, 0.03305076066644353, 0.028735271103684058, 0.03966961113288402, 0.029082468902683317, 0.028285569241373782, 0.031786755430356486, 0.024404779108853858, 0.026129373614987225, 0.03286225269330064, 0.0337885577191429, 0.037435419977679456, 0.027487698789152224, 0.02431364360404831, 0.03695118040711042, 0.05126648287442151, 0.04107233842769607, 0.03979475798462972, 0.03740966627043441, 0.030822212943554483, 0.05058778089333995, 0.03679756266194399, 0.06998625264367124, 0.03794562219242631, 0.03310200401794181, 0.05331291012493153, 0.07986441365482183, 0.06900775599719644, 0.08219705262724887, 0.105874487190267, 0.10342988581359616, 0.08019517918681268, 0.08692292530550771, 0.113978355113441, 0.09103658535785254, 0.08330700273089763, 0.09023708512793886, 0.06817086680024753, 0.09733919241124256, 0.06544890074726599, 0.0734814660643719, 0.03886987577445243, 0.033151154927677444, 0.07391828687885042, 0.12902165265322205, 0.16726327412564035, 0.2647359446458325, 0.4179242572687573, 0.4541710636471308, 0.37629500747418, 0.3240957615905829, 0.2051963217492506, 0.07266588290723769, 0.036843269718234525, 0.05208312696082423, 0.026365044379277364, 0.04304862993377523, 0.03843764665504956, 0.04830679502266177, 0.057360927302557374, 0.06550536976828003, 0.03740542151100047, 0.08629363539797757, 0.0592656471636982, 0.0498517781492637, 
0.04573315868341099, 0.04517963641752231, 0.056639635044659624, 0.03210377504774208, 0.04591194405625765, 0.0270964657791688, 0.04062592174152552, 0.039282823305607964, 0.034139260725464984, 0.030730966608705536, 0.0257602056376013, 0.03354067520866908, 0.02882918823621897, 0.02923878376561263, 0.0564366148759929, 0.036253452623873764, 0.02504495217929072, 0.040091125177588, 0.02658634779690779, 0.02667635918064909, 0.03370366542143037, 0.039955314845191145, 0.03135622152872908, 0.059506780695663314, 0.025254987757541952, 0.038034923152503126, 0.02883708074109163, 0.02606771741119524, 0.039180311098300204, 0.04173873330363966, 0.024621574626190273]

def gaussian(x, amp, cen, wid):
    "1-d absorption Gaussian on a unit continuum: gaussian(x, amp, cen, wid); wid is the FWHM"
    return 1 - amp*np.exp(-(x-cen)**2 /(2*(wid/2.355)**2))

def nebu(x, amp, cen, wid):
    "1-d Gaussian without continuum: nebu(x, amp, cen, wid); wid is the FWHM"
    return -amp*np.exp(-(x-cen)**2 /(2*(wid/2.355)**2))

gauss = Model(gaussian)
pars = Parameters()
pars.add('amp', value=0.2, min=0.01, max=1. )
#pars.add('cen', value=x_val[np.argmax(y_val)], min=4472, max=4478)
pars.add('cen', value=4475, min=4470, max=4480)
pars.add('wid', value=5, min=1, max=10.)                

gauss2 = Model(nebu, prefix='neb_')
pars.update(gauss2.make_params())                                        

pars['neb_amp'].set(-1, min=-4, max=-0.1)
#pars['neb_cen'].set(x_val[np.argmax(y_val)], min=4470, max=4480)
pars['neb_cen'].set(4475, min=4470, max=4480)
pars['neb_wid'].set(0.8, min=0.1, max=2.)

mod = gauss + gauss2

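# work on NumPy arrays; weights of 1/err give an error-weighted least-squares fit
result = mod.fit(np.asarray(y_val), pars, x=np.asarray(x_val), weights=1.0/np.asarray(err))
comp = result.eval_components(result.params, x=np.asarray(x_val))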

print(result.fit_report())

fig, ax = plt.subplots(figsize=(6,6))    
plt.plot(x_val, y_val, 'k-', lw=2, label='data')
plt.plot(x_val, result.init_fit, '--', c='gray', label='initial pars')
plt.plot(x_val, result.best_fit, 'r-', lw=2, label='model fit')
plt.plot(x_val, comp['gaussian'], '--', c='limegreen', lw=2, label='gauss')
plt.plot(x_val, 1+comp['neb_'], '--', c='orange', lw=2, label='gauss2')                            
plt.legend(loc='best',fontsize=10, handlelength=2, frameon=False)                               

plt.show()

Here is the fit report:

[[Model]]
    (Model(gaussian) + Model(nebu, prefix='neb_'))
[[Fit Statistics]]
    # fitting method   = leastsq
    # function evals   = 155
    # data points      = 125
    # variables        = 6
    chi-square         = 66.8949742
    reduced chi-square = 0.56214264
    Akaike info crit   = -66.1487371
    Bayesian info crit = -49.1788547
[[Variables]]
    amp:      0.07859727 +/- 0.01341948 (17.07%) (init = 0.2)
    cen:      4474.64017 +/- 0.37950199 (0.01%) (init = 4477)
    wid:      8.08798219 +/- 0.87892008 (10.87%) (init = 5)
    neb_amp: -1.44779784 +/- 0.14587605 (10.08%) (init = -1)
    neb_cen:  4475.51782 +/- 0.02445057 (0.00%) (init = 4475)
    neb_wid:  1.06240785 +/- 0.05036105 (4.74%) (init = 0.8)
[[Correlations]] (unreported correlations are < 0.100)
    C(amp, wid)         = -0.721
    C(neb_amp, neb_wid) =  0.652
    C(amp, neb_wid)     =  0.559
    C(wid, neb_wid)     = -0.373
    C(neb_cen, neb_wid) = -0.163
    C(neb_amp, neb_cen) = -0.150
    C(cen, neb_cen)     =  0.138
    C(amp, neb_amp)     =  0.132
    C(amp, neb_cen)     = -0.132
    C(wid, neb_cen)     =  0.123

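In some cases, choosing different initial values changes the fit result (for example cen=4472), so I wonder whether this has to do with the method used to give the initial values, or whether it is just the noise and errors in the data that make the fit poor.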

There are several different ways to specify initial values for parameters in lmfit.

Parameters.add() adds a Parameter to the Parameters ordered dictionary. When you add a Parameter this way, you can set an initial value and also set other attributes (notably min, max, vary and expr).

Parameter.set() sets one or more attributes (value, min, max, vary, expr) of an existing Parameter. You can also just set these attributes explicitly, as in

pars['neb_wid'].value = 0.8

These all work with Parameter objects and the Parameters collection.

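In addition, lmfit.Model has a make_params() method that creates a Parameters collection for that model. You use this in your example. This method can take initial values for any of the parameters, or you can modify the resulting Parameters after they have been created.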

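A Model may have one or more parameter hints to help the Model create its Parameters. Parameter hints can include initial values, but they are usually used to set bounds or expressions, so that you can express "for this model, parameter foo must be positive, and parameter bar must be = 2*foo - baz". In this way, parameter hints belong to the model.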

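These comments are all about the mechanics of how to set initial values for parameters. Deciding what those initial values should be is a completely different matter. A global solver such as differential evolution or (possibly better) AMPGO, or brute-force stepping through a finite number of options (all of these are available in lmfit), can be useful but may be time-consuming.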

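To be sure, two overlapping Gaussians with nearly the same centers (both of your center parameters have value=4475, min=4470, max=4480) and widths of the same order of magnitude will be hard to distinguish, and this will lead to an unstable fit. Do you really expect a superposition of Gaussians with almost the same center (but not identical, otherwise you would constrain them to be equal), widths that are not that different, and amplitudes of opposite sign? If so, then yes, that looks like a hard problem to me!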