如何使用 google jax 进行曲线拟合?

how to do curve fitting using google jax?

扩展 http://implicit-layers-tutorial.org/neural_odes/ 中的示例,我试图使用 google jax 模仿 scipy 和 scipy.optimize.curve_fit 中的曲线拟合函数。要拟合的函数是一阶 ODE。

#Generate toy data for first order ode.

import jax.numpy as jnp
import jax
import numpy as np


#input  data 
u = np.zeros(100)  
u[10:50] = 1
t = np.arange(len(u))
u = jnp.array(u)

#first order ODE
def f(y,t,k,tau,u):
 
  return (k*u[t]-y)/tau
  
#Euler integration
def odeint_euler(f, y0, t, *args):
  def step(state, t):
    y_prev, t_prev = state
    dt = t - t_prev
    y = y_prev + dt * f(y_prev, t_prev, *args)
    return (y, t), y
  _, ys = jax.lax.scan(step, (y0, t[0]), t[1:])
  return ys

pred = odeint_euler(f, jnp.array([0.0]),t,2.,5.,u) 
pred_noise = pred.reshape(-1) +  0.05* np.random.randn(len(pred)) # this is the  data to be fitted

# define loss function 
def loss_function(params,u,targets):
  k,tau = params
  
  pred = odeint_euler(f, jnp.array([0.0]),t,k,tau,u)
  return jnp.sum((pred-targets)**2)      


def update(params, u, targets):
  grads = jax.grad(loss_function)(params,u, targets)
  return [w - 0.0001 * dw for w,dw  in zip(params, grads)] 


updated_params = jnp.array([1.0,2.0]) #initial parameters
for i in range(100):
  updated_params = update(updated_params, u, pred_noise)
print(updated_params)

代码运行良好。然而,与 scipy 曲线拟合相比,这个 运行 非常慢。即使在 500、1000 次迭代之后,解决方案的准确性也不好。 上面的代码有什么问题?知道如何使代码 运行 更快并获得更准确的解决方案吗? jax有没有更好的曲线拟合方法?

我发现您的方法存在两个总体问题:

  1. 您的代码 运行 运行缓慢的原因是您在 Python 中进行循环,这会在每个循环中产生 JAX 的调度开销。我建议使用 JAX 的内置工具来最小化损失函数;例如:
from jax.scipy.optimize import minimize
result = minimize(
    loss_function, x0=jnp.array([1.0,2.0]),
    method='BFGS', args=(u, pred_noise))
  1. 您的准确度不接近 scipy 的原因可能是因为 JAX 默认为 32 位计算(参见 Double (64 bit) Precision)。要 运行 您的 64 位代码,您可以 运行 此块在任何其他导入之前:
from jax import config
config.update('jax_enable_x64', True)