如何使用 google jax 进行曲线拟合?
how to do curve fitting using google jax?
扩展 http://implicit-layers-tutorial.org/neural_odes/ 中的示例,我试图使用 google jax 模仿 scipy 和 scipy.optimize.curve_fit
中的曲线拟合函数。要拟合的函数是一阶 ODE。
#Generate toy data for first order ode.
import jax.numpy as jnp
import jax
import numpy as np
#input data
u = np.zeros(100)
u[10:50] = 1
t = np.arange(len(u))
u = jnp.array(u)
#first order ODE
def f(y,t,k,tau,u):
return (k*u[t]-y)/tau
#Euler integration
def odeint_euler(f, y0, t, *args):
def step(state, t):
y_prev, t_prev = state
dt = t - t_prev
y = y_prev + dt * f(y_prev, t_prev, *args)
return (y, t), y
_, ys = jax.lax.scan(step, (y0, t[0]), t[1:])
return ys
pred = odeint_euler(f, jnp.array([0.0]),t,2.,5.,u)
pred_noise = pred.reshape(-1) + 0.05* np.random.randn(len(pred)) # this is the data to be fitted
# define loss function
def loss_function(params,u,targets):
k,tau = params
pred = odeint_euler(f, jnp.array([0.0]),t,k,tau,u)
return jnp.sum((pred-targets)**2)
def update(params, u, targets):
grads = jax.grad(loss_function)(params,u, targets)
return [w - 0.0001 * dw for w,dw in zip(params, grads)]
updated_params = jnp.array([1.0,2.0]) #initial parameters
for i in range(100):
updated_params = update(updated_params, u, pred_noise)
print(updated_params)
代码运行良好。然而,与 scipy 曲线拟合相比,这个 运行 非常慢。即使在 500、1000 次迭代之后,解决方案的准确性也不好。
上面的代码有什么问题?知道如何使代码 运行 更快并获得更准确的解决方案吗? jax
有没有更好的曲线拟合方法?
我发现您的方法存在两个总体问题:
- 您的代码 运行 运行缓慢的原因是您在 Python 中进行循环,这会在每个循环中产生 JAX 的调度开销。我建议使用 JAX 的内置工具来最小化损失函数;例如:
from jax.scipy.optimize import minimize
result = minimize(
loss_function, x0=jnp.array([1.0,2.0]),
method='BFGS', args=(u, pred_noise))
- 您的准确度不接近 scipy 的原因可能是因为 JAX 默认为 32 位计算(参见 Double (64 bit) Precision)。要 运行 您的 64 位代码,您可以 运行 此块在任何其他导入之前:
from jax import config
config.update('jax_enable_x64', True)
扩展 http://implicit-layers-tutorial.org/neural_odes/ 中的示例,我试图使用 google jax 模仿 scipy 和 scipy.optimize.curve_fit
中的曲线拟合函数。要拟合的函数是一阶 ODE。
#Generate toy data for first order ode.
import jax.numpy as jnp
import jax
import numpy as np
#input data
u = np.zeros(100)
u[10:50] = 1
t = np.arange(len(u))
u = jnp.array(u)
#first order ODE
def f(y,t,k,tau,u):
return (k*u[t]-y)/tau
#Euler integration
def odeint_euler(f, y0, t, *args):
def step(state, t):
y_prev, t_prev = state
dt = t - t_prev
y = y_prev + dt * f(y_prev, t_prev, *args)
return (y, t), y
_, ys = jax.lax.scan(step, (y0, t[0]), t[1:])
return ys
pred = odeint_euler(f, jnp.array([0.0]),t,2.,5.,u)
pred_noise = pred.reshape(-1) + 0.05* np.random.randn(len(pred)) # this is the data to be fitted
# define loss function
def loss_function(params,u,targets):
k,tau = params
pred = odeint_euler(f, jnp.array([0.0]),t,k,tau,u)
return jnp.sum((pred-targets)**2)
def update(params, u, targets):
grads = jax.grad(loss_function)(params,u, targets)
return [w - 0.0001 * dw for w,dw in zip(params, grads)]
updated_params = jnp.array([1.0,2.0]) #initial parameters
for i in range(100):
updated_params = update(updated_params, u, pred_noise)
print(updated_params)
代码运行良好。然而,与 scipy 曲线拟合相比,这个 运行 非常慢。即使在 500、1000 次迭代之后,解决方案的准确性也不好。
上面的代码有什么问题?知道如何使代码 运行 更快并获得更准确的解决方案吗? jax
有没有更好的曲线拟合方法?
我发现您的方法存在两个总体问题:
- 您的代码 运行 运行缓慢的原因是您在 Python 中进行循环,这会在每个循环中产生 JAX 的调度开销。我建议使用 JAX 的内置工具来最小化损失函数;例如:
from jax.scipy.optimize import minimize
result = minimize(
loss_function, x0=jnp.array([1.0,2.0]),
method='BFGS', args=(u, pred_noise))
- 您的准确度不接近 scipy 的原因可能是因为 JAX 默认为 32 位计算(参见 Double (64 bit) Precision)。要 运行 您的 64 位代码,您可以 运行 此块在任何其他导入之前:
from jax import config
config.update('jax_enable_x64', True)