在 python 中的非线性最小二乘实现中,参数不会以较低的公差收敛
Parameters do not converge at a lower tolerance in nonlinear least square implementation in python
作为学习过程,我正在将我的一些 R 代码翻译成 Python,尤其是尝试 JAX
进行 autodiff。
在实现非线性最小二乘的函数中,当我将公差设置为 1e-8 时,估计的参数在多次迭代后几乎相同,但算法似乎从未收敛。
然而,R 代码在第 12 个中间点 tol=1e-8 和第 14 个中间点 tol=1e-9 处收敛。估计的参数与 Python 实现的参数几乎相同。
我认为这与浮点有关,但不确定我可以改进哪一步以使收敛速度与 R 中看到的一样快。
这是我的代码,大部分步骤和R一样
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg as ola
def update_parm(X, y, fun, dfun, parm, theta, wt):
len_y = len(y)
mean_fun = fun(X, parm)
if (type(wt) == bool):
if (wt):
var_fun = np.exp(theta * np.log(mean_fun))
sqrtW = 1 / np.sqrt(var_fun ** 2)
else:
sqrtW = 1
else:
sqrtW = wt
gradX = dfun(x, parm)
weighted_X = sqrtW.reshape(len_y, 1) * gradX
z = gradX @ parm + (y - mean_fun)
weighted_z = sqrtW * z
qr_gradX = ola.qr(weighted_X, mode="economic")
Q = qr_gradX[0]
R = qr_gradX[1]
new_parm = ola.solve(R, np.dot(Q.T, weighted_z))
return new_parm
def nls_irwls(X, y, fun, dfun, init, theta = 1, tol = 1e-8, maxiter = 500):
old_parm = init
iter = 0
while (iter < maxiter):
new_parm = update_parm(X, y, fun, dfun, parm=old_parm, theta=theta, wt=True)
parm_diff = np.max(np.abs(new_parm - old_parm) / np.abs(old_parm))
print(parm_diff)
if (parm_diff < tol) :
break
else:
old_parm = new_parm
iter += 1
print(new_parm)
if (iter == maxiter):
print("The algorithm failed to converge")
else:
return {"Estimated coefficient": new_parm}
x = np.array([0.25, 0.5, 0.75, 1, 1.25, 2, 3, 4, 5, 6, 8])
y = np.array([2.05, 1.04, 0.81, 0.39, 0.30, 0.23, 0.13, 0.11, 0.08, 0.10, 0.06])
def model(x, W):
comp1 = jnp.exp(W[0])
comp2 = jnp.exp(-jnp.exp(W[1]) * x)
comp3 = jnp.exp(W[2])
comp4 = jnp.exp(-jnp.exp(W[3]) * x)
return comp1 * comp2 + comp3 * comp4
init = np.array([0.69, 0.69, -1.6, -1.6])
#autodiff
model_grad = jax.jit(jax.jacfwd(model, argnums=1))
#manual derivative
def dModel(x, W):
e1 = np.exp(W[1])
e2 = np.exp(W[3])
e5 = np.exp(-(x * e1))
e6 = np.exp(-(x * e2))
e7 = np.exp(W[0])
e8 = np.exp(W[2])
b1 = e5 * e7
b2 = -(x * e5 * e7 * e1)
b3 = e6 * e8
b4 = -(x * e6 * e8 * e2)
return np.array([b1, b2, b3, b4]).T
nls_irwls(x, y, model, model_grad, init=init, theta=1, tol=1e-8, maxiter=50)
nls_irwls(x, y, model, dModel, init=init, theta=1, tol=1e-8, maxiter=50)
需要注意的一件事是,默认情况下,JAX 以 32 位执行计算,而 R 和 numpy 等工具以 64 位执行计算。由于 1E-8
处于 32 位浮点精度的边缘,我怀疑这就是您的程序无法收敛的原因。
您可以通过将其放在脚本的开头来启用 64 位计算:
from jax import config
config.update('jax_enable_x64', True)
执行此操作后,您的程序按预期收敛。有关详细信息,请参阅 JAX Sharp Bits: Double Precision。
作为学习过程,我正在将我的一些 R 代码翻译成 Python,尤其是尝试 JAX
进行 autodiff。
在实现非线性最小二乘的函数中,当我将公差设置为 1e-8 时,估计的参数在多次迭代后几乎相同,但算法似乎从未收敛。
然而,R 代码在第 12 个中间点 tol=1e-8 和第 14 个中间点 tol=1e-9 处收敛。估计的参数与 Python 实现的参数几乎相同。
我认为这与浮点有关,但不确定我可以改进哪一步以使收敛速度与 R 中看到的一样快。
这是我的代码,大部分步骤和R一样
import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg as ola
def update_parm(X, y, fun, dfun, parm, theta, wt):
len_y = len(y)
mean_fun = fun(X, parm)
if (type(wt) == bool):
if (wt):
var_fun = np.exp(theta * np.log(mean_fun))
sqrtW = 1 / np.sqrt(var_fun ** 2)
else:
sqrtW = 1
else:
sqrtW = wt
gradX = dfun(x, parm)
weighted_X = sqrtW.reshape(len_y, 1) * gradX
z = gradX @ parm + (y - mean_fun)
weighted_z = sqrtW * z
qr_gradX = ola.qr(weighted_X, mode="economic")
Q = qr_gradX[0]
R = qr_gradX[1]
new_parm = ola.solve(R, np.dot(Q.T, weighted_z))
return new_parm
def nls_irwls(X, y, fun, dfun, init, theta = 1, tol = 1e-8, maxiter = 500):
old_parm = init
iter = 0
while (iter < maxiter):
new_parm = update_parm(X, y, fun, dfun, parm=old_parm, theta=theta, wt=True)
parm_diff = np.max(np.abs(new_parm - old_parm) / np.abs(old_parm))
print(parm_diff)
if (parm_diff < tol) :
break
else:
old_parm = new_parm
iter += 1
print(new_parm)
if (iter == maxiter):
print("The algorithm failed to converge")
else:
return {"Estimated coefficient": new_parm}
x = np.array([0.25, 0.5, 0.75, 1, 1.25, 2, 3, 4, 5, 6, 8])
y = np.array([2.05, 1.04, 0.81, 0.39, 0.30, 0.23, 0.13, 0.11, 0.08, 0.10, 0.06])
def model(x, W):
comp1 = jnp.exp(W[0])
comp2 = jnp.exp(-jnp.exp(W[1]) * x)
comp3 = jnp.exp(W[2])
comp4 = jnp.exp(-jnp.exp(W[3]) * x)
return comp1 * comp2 + comp3 * comp4
init = np.array([0.69, 0.69, -1.6, -1.6])
#autodiff
model_grad = jax.jit(jax.jacfwd(model, argnums=1))
#manual derivative
def dModel(x, W):
e1 = np.exp(W[1])
e2 = np.exp(W[3])
e5 = np.exp(-(x * e1))
e6 = np.exp(-(x * e2))
e7 = np.exp(W[0])
e8 = np.exp(W[2])
b1 = e5 * e7
b2 = -(x * e5 * e7 * e1)
b3 = e6 * e8
b4 = -(x * e6 * e8 * e2)
return np.array([b1, b2, b3, b4]).T
nls_irwls(x, y, model, model_grad, init=init, theta=1, tol=1e-8, maxiter=50)
nls_irwls(x, y, model, dModel, init=init, theta=1, tol=1e-8, maxiter=50)
需要注意的一件事是,默认情况下,JAX 以 32 位执行计算,而 R 和 numpy 等工具以 64 位执行计算。由于 1E-8
处于 32 位浮点精度的边缘,我怀疑这就是您的程序无法收敛的原因。
您可以通过将其放在脚本的开头来启用 64 位计算:
from jax import config
config.update('jax_enable_x64', True)
执行此操作后,您的程序按预期收敛。有关详细信息,请参阅 JAX Sharp Bits: Double Precision。