有没有办法为 JiTCODE 的函数参数提供数值函数而不是符号函数?

Is there a way to supply a numerical function to JiTCODE’s function argument instead of symbolic one?

我正在通过神经网络获取一个函数(学习动力系统),并希望将其传递给 JiTCODE 以计算轨迹、Lyapunov 指数等。根据 JiTCODE 文档,函数 f 必须是符号函数。有什么办法可以改变这一点,因为最终 JiTCODE 将对符号函数进行 lambdify 化?

基本上,这就是我现在正在做的事情:

# learns derviates from the Neural net model
# returns an array of numbers [\dot{x},\dot{y}] for input [x,y]
learned_fn = lambda t, y0: NN_model(t, y0) 

ODE = jitcode_lyap(learned_fn, n_lyap=2)
ODE.set_integrator("vode")

直接引用链接文档

JiTCODE takes an iterable (or generator function or dictionary) of symbolic expressions, which it translates to C code, compiles on the fly,

所以没有进行 lambdification,函数被解析,而不仅仅是求值。

但一般来说应该没问题,你只要用JITCODE提供的符号向量y和符号t代替右边的函数参数t,y ODE.

首先请注意,JiTCODE 不会像您的 learned_fn 这样的常规函数​​作为输入。它采用符号表达式的迭代或返回符号表达式的生成器函数。这就是您的示例代码可能会产生错误的原因。

你想要什么

您可以通过更改 f 属性 并告诉它编译实际导数失败,将任何具有正确签名的导数“注入”到 JiTCODE 中。这是一个最小的例子:

from jitcode import jitcode, y

ODE = jitcode([0])
ODE.f = lambda t,y: y[0]
ODE.compile_attempt = False
ODE.set_integrator("dopri5")
ODE.set_initial_value([1],0.0)

for time in range(30):
    print(time,*ODE.integrate(time))

为什么您可能不想这样做

暂时忽略 Lyapunov 指数,JiTCODE 的全部意义在于为您硬编码导数并将其传递给 SciPy 的 odesolve_ivp 执行实际整合。因此,上面的示例代码只是将函数传递给 SciPy 的标准集成器(此处为 ode)的一种过于复杂的方式,没有任何优势。如果您的 NN_model 一开始就非常有效地实现,您甚至可能无法从 JiTCODE 的自动编译中获得速度提升。

使用 JiTCODE 的 Lyapunov 指数功能的主要原因是它自动从导数的符号表示中获取切线向量演化的 Jacobian 和 ODE(Benettin 方法需要)。没有符号输入,它不可能做到这一点。理论上您也可以注入一个切向量 ODE,但是您将再一次为 JiTCODE 留下很少的空间,您最好直接使用 SciPy 的 odesolve_ivp

您可能需要什么

如果您想使用 JiTCODE,您需要编写一小段代码,将神经网络训练的输出转换为 JiTCODE 所需的 ODE 的符号表示。这可能没有听起来那么可怕。你只需要得到训练好的系数,代入到神经网络一般形式的方程中即可。

如果你很幸运并且你的NN_model完全支持duck typing(和),你可以这样做:

from jitcode import t,y
n = 10 # dimension of your ODE
NN_input = [y(i) for i in range(n)]
learned_fn = NN_model(t,NN_input)[1]

我们的想法是,您使用抽象符号输入(tNN_input)为 NN_model 提供一次。 NN_model 然后一旦作用于这个抽象输入,就会为您提供一个抽象结果(这里您需要 duck-typing 支持)。如果我正确地解释了你的 NN_model 的输出,那么这个结果的第二个组成部分应该是 JiTCODE 要求的抽象导数作为输入。

请注意,您的 NN_model 似乎期望维度是索引,但 JiTCODE 的 y 期望维度是函数参数。所以你不能只选择NN_input = y,你必须像上面那样转换它。