使用 scipy least_squares 时出现 ValueError
ValueError when using scipy least_squares
我正在尝试将我的数据拟合到一个函数中。我一直在使用此示例代码作为指南 http://docs.scipy.org/doc/scipy/reference/tutorial/optimize.html#example-of-solving-a-fitting-problem。我的代码如下:
from scipy.optimize import least_squares
import numpy as np
import matplotlib.pyplot as plt
def model(x, u):
return -x[0] * np.sqrt((x[1]/u) - 1)
def fun(x, u, y):
return y - model(x, u)
def jac(x, u, y):
J = np.empty((u.size, x.size))
J[:, 0] = np.sqrt((x[1]/u) - 1)
J[:, 1] = x[0] / (2 * u * np.sqrt((x[1]/u) - 1))
return J
u = np.array(T_h2)
y = np.array(lnR2)
x0 = np.array([0.1,0.2])
res = least_squares(fun, x0, jac=jac, bounds=(0, 100), args=(u, y), verbose=1)
print(res.x)
u_test = T_h2
y_test = model(res.x, u_test)
plt.plot(u, y, 'o', markersize=4, label='data')
plt.plot(u_test, y_test, label='fitted model')
plt.xlabel("u")
plt.ylabel("y")
plt.legend(loc='lower right')
plt.show()
但是,当我 运行 我的代码出现错误 "ValueError: Residuals are not finite in the initial point." 我该如何解决这个问题?
编辑:
T_h2 = [234.382, 234.353, 234.435, 234.709, 235.169, 235.803,
236.661, 237.688, 238.697, 239.658, 240.743, 241.813, 242.784, 243.739,
244.791, 245.675, 246.666, 247.615, 248.579, 249.481, 250.336, 251.311,
252.211, 253.058, 253.976, 254.831, 255.738, 256.599, 257.594, 258.482,
259.279, 260.233, 261.112, 262.103, 263.003, 263.9, 264.764, 265.688,
266.629, 267.491, 268.415, 269.285, 270.188, 271.129, 272., 272.935,
273.773, 274.714, 275.581, 276.549, 277.411, 278.334, 279.276, 280.146,
281.006, 281.905, 282.819, 283.803, 284.681, 285.513, 286.49, 287.324,
288.173, 289.105, 290.039, 290.991, 291.795, 292.694, 293.648, 294.522,
295.398, 296.296, 297.25, 298.134, 299.024, 299.912, 300.808, 301.732,
302.635, 303.603, 304.476, 305.35, 306.223, 307.18, 308.091, 308.938,
309.902, 310.792, 311.663, 312.566, 313.412, 314.284, 315.252, 316.126,
317.002, 317.913, 318.81, 319.669, 320.626, 321.523, 322.417, 323.281,
324.245]
lnR2 = [-16.333025681623091, -14.872111670594926, -14.892057965675207, -15.03694367579511, -14.388711659567424, -16.519631908799834, -14.047440985059174, -13.245512823492424, -12.012664970474015, -11.592570515633696, -11.415244487948224, -11.250423587326582, -11.043358068566182, -10.782270761445371, -10.57008012745084, -10.348870666290271, -10.191384942587591, -10.048855650333838, -9.9256240231933077, -9.7926739093465187, -9.6730532317943059, -9.5334101176483124, -9.3859588951369251, -9.2475985534653571, -9.1166053550752206, -9.0088611502583475, -8.8739120056364289, -8.7650034909964933, -8.6823151628382362, -8.6015380878989749, -8.5167589793011746, -8.4314862875533017, -8.364006279047107, -8.3069822249135825, -8.2571447519527315, -8.2111410588354676, -8.1684964170797887, -8.1396219459464945, -8.1149140801354562, -8.0937213212661057, -8.0742199830459658, -8.057615869538207, -8.0494949879212623, -8.0435977497085211, -8.0409171951906373, -8.0461036780308994, -8.0490116406609502, -8.0525194174270123, -8.0653078013251491, -8.0816100755759432, -8.0974305556597912, -8.1152346160995883, -8.1394956678268393, -8.1664274771185354, -8.1980306181968547, -8.2299693351364844, -8.2652082284364567, -8.3050428664294742, -8.3484319768441626, -8.3927260630797864, -8.4461326801347543, -8.5003378964708105, -8.5595337634985853, -8.6098956222034229, -8.6806395376767043, -8.7463523398937717, -8.819120148846844, -8.8938512284941815, -8.975439857789393, -9.0604311437041982, -9.160016977929974, -9.2544272624693313, -9.360134170694149, -9.48357662093877, -9.5792580093353656, -9.7144993201777972, -9.8715380997132574, -10.027248712603699, -10.177417977875871, -10.352374953002723, -10.517136866838991, -10.715774762340427, -10.913431451028842, -11.123817784132052, -11.345932175131191, -11.567233115011238, -11.775872934970939, -11.992878335292444, -12.21474839185972, -12.382545106102548, -12.45951145012326, -12.697558713087753, -12.870960915450144, -13.122795212623657, -13.096364398875277, -13.3741438677707, -13.323032960465998, -13.436772613480292, -13.561709556757362, -14.198404910172693, -13.896250284916482, -13.535947150817048, -14.727538560421378]
问题是:
- 案例A:你的初始点
- 情况 B:您的函数
model
给出起点 x0 = np.array([0.1,0.2])
(还有 u,y
),调用 fun(x0, u, y)
,发生以下情况:
np.sqrt((x[1]/u) - 1) # part of model(x, u)
= np.sqrt((0.2 / u) - 1)
= np.sqrt(some_near_zero_vector - 1) # because u much bigger than 0.2
= np.sqrt(some_near_minus_one_vector)
= NaN-vector, which is not finite! # because of negative components in sqrt
我正在尝试将我的数据拟合到一个函数中。我一直在使用此示例代码作为指南 http://docs.scipy.org/doc/scipy/reference/tutorial/optimize.html#example-of-solving-a-fitting-problem。我的代码如下:
from scipy.optimize import least_squares
import numpy as np
import matplotlib.pyplot as plt
def model(x, u):
return -x[0] * np.sqrt((x[1]/u) - 1)
def fun(x, u, y):
return y - model(x, u)
def jac(x, u, y):
J = np.empty((u.size, x.size))
J[:, 0] = np.sqrt((x[1]/u) - 1)
J[:, 1] = x[0] / (2 * u * np.sqrt((x[1]/u) - 1))
return J
u = np.array(T_h2)
y = np.array(lnR2)
x0 = np.array([0.1,0.2])
res = least_squares(fun, x0, jac=jac, bounds=(0, 100), args=(u, y), verbose=1)
print(res.x)
u_test = T_h2
y_test = model(res.x, u_test)
plt.plot(u, y, 'o', markersize=4, label='data')
plt.plot(u_test, y_test, label='fitted model')
plt.xlabel("u")
plt.ylabel("y")
plt.legend(loc='lower right')
plt.show()
但是,当我 运行 我的代码出现错误 "ValueError: Residuals are not finite in the initial point." 我该如何解决这个问题?
编辑:
T_h2 = [234.382, 234.353, 234.435, 234.709, 235.169, 235.803,
236.661, 237.688, 238.697, 239.658, 240.743, 241.813, 242.784, 243.739,
244.791, 245.675, 246.666, 247.615, 248.579, 249.481, 250.336, 251.311,
252.211, 253.058, 253.976, 254.831, 255.738, 256.599, 257.594, 258.482,
259.279, 260.233, 261.112, 262.103, 263.003, 263.9, 264.764, 265.688,
266.629, 267.491, 268.415, 269.285, 270.188, 271.129, 272., 272.935,
273.773, 274.714, 275.581, 276.549, 277.411, 278.334, 279.276, 280.146,
281.006, 281.905, 282.819, 283.803, 284.681, 285.513, 286.49, 287.324,
288.173, 289.105, 290.039, 290.991, 291.795, 292.694, 293.648, 294.522,
295.398, 296.296, 297.25, 298.134, 299.024, 299.912, 300.808, 301.732,
302.635, 303.603, 304.476, 305.35, 306.223, 307.18, 308.091, 308.938,
309.902, 310.792, 311.663, 312.566, 313.412, 314.284, 315.252, 316.126,
317.002, 317.913, 318.81, 319.669, 320.626, 321.523, 322.417, 323.281,
324.245]
lnR2 = [-16.333025681623091, -14.872111670594926, -14.892057965675207, -15.03694367579511, -14.388711659567424, -16.519631908799834, -14.047440985059174, -13.245512823492424, -12.012664970474015, -11.592570515633696, -11.415244487948224, -11.250423587326582, -11.043358068566182, -10.782270761445371, -10.57008012745084, -10.348870666290271, -10.191384942587591, -10.048855650333838, -9.9256240231933077, -9.7926739093465187, -9.6730532317943059, -9.5334101176483124, -9.3859588951369251, -9.2475985534653571, -9.1166053550752206, -9.0088611502583475, -8.8739120056364289, -8.7650034909964933, -8.6823151628382362, -8.6015380878989749, -8.5167589793011746, -8.4314862875533017, -8.364006279047107, -8.3069822249135825, -8.2571447519527315, -8.2111410588354676, -8.1684964170797887, -8.1396219459464945, -8.1149140801354562, -8.0937213212661057, -8.0742199830459658, -8.057615869538207, -8.0494949879212623, -8.0435977497085211, -8.0409171951906373, -8.0461036780308994, -8.0490116406609502, -8.0525194174270123, -8.0653078013251491, -8.0816100755759432, -8.0974305556597912, -8.1152346160995883, -8.1394956678268393, -8.1664274771185354, -8.1980306181968547, -8.2299693351364844, -8.2652082284364567, -8.3050428664294742, -8.3484319768441626, -8.3927260630797864, -8.4461326801347543, -8.5003378964708105, -8.5595337634985853, -8.6098956222034229, -8.6806395376767043, -8.7463523398937717, -8.819120148846844, -8.8938512284941815, -8.975439857789393, -9.0604311437041982, -9.160016977929974, -9.2544272624693313, -9.360134170694149, -9.48357662093877, -9.5792580093353656, -9.7144993201777972, -9.8715380997132574, -10.027248712603699, -10.177417977875871, -10.352374953002723, -10.517136866838991, -10.715774762340427, -10.913431451028842, -11.123817784132052, -11.345932175131191, -11.567233115011238, -11.775872934970939, -11.992878335292444, -12.21474839185972, -12.382545106102548, -12.45951145012326, -12.697558713087753, -12.870960915450144, -13.122795212623657, -13.096364398875277, -13.3741438677707, -13.323032960465998, -13.436772613480292, -13.561709556757362, -14.198404910172693, -13.896250284916482, -13.535947150817048, -14.727538560421378]
问题是:
- 案例A:你的初始点
- 情况 B:您的函数
model
给出起点 x0 = np.array([0.1,0.2])
(还有 u,y
),调用 fun(x0, u, y)
,发生以下情况:
np.sqrt((x[1]/u) - 1) # part of model(x, u)
= np.sqrt((0.2 / u) - 1)
= np.sqrt(some_near_zero_vector - 1) # because u much bigger than 0.2
= np.sqrt(some_near_minus_one_vector)
= NaN-vector, which is not finite! # because of negative components in sqrt