如何做曲线拟合使用谷歌jax?



扩展http://implicit-layers-tutorial.org/neural_odes/中的示例,我试图使用google jax模拟scipy,scipy.optimize.curve_fit中的曲线拟合函数。拟合函数为一阶ODE。

#Generate toy data for first order ode.
import jax.numpy as jnp
import jax
import numpy as np

#input  data 
u = np.zeros(100)  
u[10:50] = 1
t = np.arange(len(u))
u = jnp.array(u)
#first order ODE
def f(y,t,k,tau,u):

return (k*u[t]-y)/tau

#Euler integration
def odeint_euler(f, y0, t, *args):
def step(state, t):
y_prev, t_prev = state
dt = t - t_prev
y = y_prev + dt * f(y_prev, t_prev, *args)
return (y, t), y
_, ys = jax.lax.scan(step, (y0, t[0]), t[1:])
return ys
pred = odeint_euler(f, jnp.array([0.0]),t,2.,5.,u) 
pred_noise = pred.reshape(-1) +  0.05* np.random.randn(len(pred)) # this is the  data to be fitted
# define loss function 
def loss_function(params,u,targets):
k,tau = params

pred = odeint_euler(f, jnp.array([0.0]),t,k,tau,u)
return jnp.sum((pred-targets)**2)      

def update(params, u, targets):
grads = jax.grad(loss_function)(params,u, targets)
return [w - 0.0001 * dw for w,dw  in zip(params, grads)] 

updated_params = jnp.array([1.0,2.0]) #initial parameters
for i in range(100):
updated_params = update(updated_params, u, pred_noise)
print(updated_params)

代码工作正常。然而,与scipy曲线拟合相比,这运行得相当慢。即使经过500、1000次迭代,解决方案的精度也不是很好。上面的代码有什么问题?任何想法如何使代码运行更快,并得到更准确的解决方案?是否有更好的方法用jax进行曲线拟合?

我看到你的方法有两个总体问题:

  1. 你的代码运行缓慢的原因是因为你在Python中做循环,这会导致每个循环的JAX调度开销。我建议使用JAX的内置工具来最小化损失函数;例如:
from jax.scipy.optimize import minimize
result = minimize(
loss_function, x0=jnp.array([1.0,2.0]),
method='BFGS', args=(u, pred_noise))
  1. 您的精度没有接近scipy的原因可能是因为JAX默认为32位计算(参见Double(64位)Precision)。要以64位运行代码,您可以在任何其他导入之前运行此块:
from jax import config
config.update('jax_enable_x64', True)

最新更新