Combine scipy.root and Jax Jacobian

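I'm having trouble using a Jacobian computed with JAX together with scipy.root. In the example below, root works without the Jacobian but fails when I pass one. Any ideas on what I need to rewrite so that the code below works with the Jacobian?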

from jax import jacfwd
from scipy.optimize import root
import numpy as np

def objectFunction(valuesEndo, varNamesEndo, valuesExo, varNamesExo, equations): 
  for i in range(len(varNamesExo)):
      exec("%s = %.10f" %(varNamesExo[i], valuesExo[i]))

  for i in range(len(varNamesEndo)):
    exec("%s = %.10f" %(varNamesEndo[i], valuesEndo[i]))
    
  equationVector = np.zeros(len(equations))
  for i in range(len(equations)):
      exec('equationVector[%d] = eval(equations[%d])' %(i, i))    
      
  return equationVector

varNamesEndo = ['x', 'y']
valuesEndoInitialGuess = [1., 1.]

varNamesExo = ['a', 'b']
valuesExo = [1., 1.]

equations = ['a*x+b*y**2-4',
            'np.exp(x) + x*y - 3']

equations = ['a*x**2 + b*y**2',
            'a*x**2 - b*y**2']

# Without Jacobian
sol1 =  root(fun=objectFunction,
            x0=valuesEndoInitialGuess, 
            args=(varNamesEndo, valuesExo, varNamesExo, equations))
#----> Works

# With Jacobian
jac  = jacfwd(objectFunction)
sol2 =  root(fun=objectFunction,
            x0=valuesEndoInitialGuess, 
            args=(varNamesEndo, valuesExo, varNamesExo, equations),
            jac=jac)
#----> Not working

At least these lines seem to be causing the problem:

for i in range(len(varNamesEndo)):
        exec("%s = %.10f" %(varNamesEndo[i], valuesEndo[i]))

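There are two problems:

  1. To perform automatic differentiation, JAX replaces concrete values with tracers. That means your approach of formatting the values as strings (with "%.10f") and then evaluating those strings won't work.
  2. In addition, you are trying to assign traced values into a standard NumPy array. You should use a JAX array instead, because it knows how to handle traced values.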
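Both failure modes can be reproduced in isolation. The sketch below uses jax.grad on a scalar as a stand-in for the jacfwd call in the question; the helper names (fails_under_tracing, via_string, via_numpy, via_jax) are hypothetical and exist only for illustration. The exact exception class JAX raises may differ between versions, so the helper just checks that an error occurs.

```python
import jax
import jax.numpy as jnp
import numpy as np


def fails_under_tracing(f):
    """Return True if differentiating f with jax.grad raises an error."""
    try:
        jax.grad(f)(1.0)
    except Exception:
        return True
    return False


# Problem 1: "%.10f" needs a concrete float, but under grad v is a tracer.
def via_string(v):
    return float("%.10f" % v) ** 2


# Problem 2: writing a tracer into a plain NumPy array needs a concrete value.
def via_numpy(v):
    out = np.zeros(1)
    out[0] = v ** 2
    return out[0]


# Working variant: keep the tracer inside JAX operations (functional update).
def via_jax(v):
    out = jnp.zeros(1)
    out = out.at[0].set(v ** 2)
    return out[0]


print(fails_under_tracing(via_string))  # True
print(fails_under_tracing(via_numpy))   # True
print(jax.grad(via_jax)(1.0))           # 2.0
```

Note that JAX arrays are immutable, which is why via_jax rebinds out to the result of .at[0].set(...) instead of assigning in place.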

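With that in mind, you can rewrite your function like this and it should work, as long as your equations use only Python arithmetic operations and JAX functions (rather than things like np.exp):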

import jax.numpy as jnp

def objectFunction(valuesEndo, varNamesEndo, valuesExo, varNamesExo, equations): 
  for i in range(len(varNamesExo)):
      exec("%s = valuesExo[%d]" %(varNamesExo[i], i))

  for i in range(len(varNamesEndo)):
    exec("%s = valuesEndo[%d]" %(varNamesEndo[i], i))
    
  equationVector = jnp.zeros(len(equations))
  for i in range(len(equations)):
      equationVector = equationVector.at[i].set(eval(equations[i]))
      
  return equationVector
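The restriction to JAX functions matters for equations such as the question's 'np.exp(x) + x*y - 3'. As a minimal sketch (residual_np, residual_jnp and differentiable are hypothetical names), NumPy's exp cannot consume a tracer, while jax.numpy's exp can be differentiated:

```python
import jax
import jax.numpy as jnp
import numpy as np


def residual_np(x):
    # np.exp tries to convert the tracer to a NumPy array, which JAX forbids
    return np.exp(x) - 3.0


def residual_jnp(x):
    # jnp.exp is traceable, so JAX can differentiate through it
    return jnp.exp(x) - 3.0


def differentiable(f):
    """Return True if jax.grad(f) can be evaluated at 0.0 without error."""
    try:
        jax.grad(f)(0.0)
        return True
    except Exception:
        return False


print(differentiable(residual_np))   # False
print(differentiable(residual_jnp))  # True
print(jax.grad(residual_jnp)(0.0))   # 1.0
```

So in the equation strings you would write jnp.exp(x) instead of np.exp(x).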

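Side note: this approach of setting variable names with exec is very brittle and error-prone; I'd suggest instead building an explicit namespace in which to evaluate the equations. For example, something like this: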

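import jax.numpy as jnp

def objectFunction(valuesEndo, varNamesEndo, valuesExo, varNamesExo, equations):
  # Map each variable name to its (possibly traced) value, and expose jnp
  # so that equations can call JAX functions such as jnp.exp.
  namespace = {
      'jnp': jnp,
      **dict(zip(varNamesEndo, valuesEndo)),
      **dict(zip(varNamesExo, valuesExo))
  }
  return jnp.array([eval(eqn, namespace) for eqn in equations])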
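To tie this back to scipy.optimize.root, here is a minimal end-to-end sketch under a few assumptions that are not in the original: it uses the question's first system with np.exp replaced by jnp.exp, it puts jnp into the namespace, and it enables JAX's 64-bit mode so the residuals and Jacobian match SciPy's float64 precision. jacfwd differentiates only with respect to the first argument, and SciPy calls jac(x, *args), so the two signatures line up.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import jacfwd
from scipy.optimize import root

# Assumption: use float64 in JAX so it matches SciPy's precision.
jax.config.update("jax_enable_x64", True)


def objectFunction(valuesEndo, varNamesEndo, valuesExo, varNamesExo, equations):
    # Evaluate each equation string in an explicit name -> value namespace.
    namespace = {
        "jnp": jnp,
        **dict(zip(varNamesEndo, valuesEndo)),
        **dict(zip(varNamesExo, valuesExo)),
    }
    return jnp.array([eval(eqn, namespace) for eqn in equations])


varNamesEndo = ["x", "y"]
varNamesExo = ["a", "b"]
valuesExo = [1.0, 1.0]
# The question's first system, with np.exp swapped for the traceable jnp.exp.
equations = ["a*x + b*y**2 - 4",
             "jnp.exp(x) + x*y - 3"]
args = (varNamesEndo, valuesExo, varNamesExo, equations)

# Jacobian with respect to valuesEndo only; rows correspond to equations.
jac = jacfwd(objectFunction)

sol = root(fun=objectFunction, x0=[1.0, 1.0], args=args, jac=jac)
print(sol.success, sol.x)
```

For this system the Jacobian can also be written by hand as [[a, 2*b*y], [exp(x) + y, x]], which is a convenient way to sanity-check what jacfwd returns.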