在 TensorFlow 中使用张量输入求解 ODE

Solving an ODE in TensorFlow with tensor inputs

我正在尝试针对多组不同的常量,求解同一个 ODE 的多个实例。

这是我的代码:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

class SimpleODEModule(tf.Module):
    """Solve a batch of ODEs dx/dt = b * exp(a * t) + a * x with BDF.

    Each row of ``parameters`` is one ``(a, b)`` pair. All rows are integrated
    together as a single vector-valued system whose state holds one entry per
    row, which is why :meth:`ode_system` indexes ``parameters`` column-wise.
    """

    def __init__(self, name=None):
        super().__init__(name=name)

    def __call__(self, t_initial, x_initial, solution_times, parameters):
        """Integrate the batched ODE and return the solution states.

        Args:
            t_initial: Initial time. ``tfp.math.ode`` solvers require a scalar
                here; passing a batched tensor (e.g. shape ``(1, batch)``) is
                what produces the "truth value of an array with more than one
                element is ambiguous" error.
            x_initial: Initial state with one entry per row of ``parameters``
                (shape ``(batch,)``).
            solution_times: 1-D increasing tensor of output times shared by
                every row (not one time per row).
            parameters: Rank-2 tensor whose column 0 is ``a`` and column 1 is
                ``b``; forwarded to :meth:`ode_system` via ``constants``.

        Returns:
            The solver's ``states``: leading dimension ``len(solution_times)``
            followed by the shape of ``x_initial``.
        """
        # No GradientTape here. The previous version computed
        # tape.gradient(...) inside the tape and discarded it, paying for a
        # recorded backward pass on every call. tfp's ODE solvers are
        # differentiable w.r.t. ``constants``, so a caller that needs
        # d(states)/d(parameters) should record this call under its own tape.
        solution = tfp.math.ode.BDF().solve(
            self.ode_system,
            t_initial,
            x_initial,
            solution_times,
            constants={'parameters': parameters},
        )
        return solution.states

    def ode_system(self, t, x, parameters):
        """Right-hand side dx/dt = b * exp(a * t) + a * x, vectorised over rows.

        Args:
            t: Current time; must broadcast against a column of ``parameters``.
            x: Current state, one entry per row of ``parameters``.
            parameters: Rank-2 tensor; column 0 is ``a``, column 1 is ``b``.

        Returns:
            dx/dt with the broadcast shape of ``a * x``.
        """
        a = parameters[:, 0]
        b = parameters[:, 1]
        # Deliberately no print(): the solver evaluates (and may trace) this
        # function many times per solve, so a debug print floods stdout.
        return tf.add(tf.multiply(b, tf.exp(tf.multiply(a, t))), tf.multiply(a, x))

# One (a, b) parameter pair per row; each row defines an independent ODE.
constants = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=tf.float32)
num_systems = constants.shape[0]

# BDF integrates all rows together as one vector-valued ODE. Therefore:
#   * the initial time is a single scalar shared by every system (a (1, n)
#     tensor here caused "The truth value of an array with more than one
#     element is ambiguous" deep inside tfp.math.ode);
#   * the initial state is 1-D with one entry per system;
#   * solution_times is a 1-D increasing list of output times shared by all
#     systems, not one time per system.
t_initial = tf.constant(0.0, dtype=tf.float32)
x_initial = tf.zeros([num_systems], dtype=tf.float32)
solution_times = tf.constant([1.0], dtype=tf.float32)

simple_ode = SimpleODEModule()

# states[i] holds x(solution_times[i]) for every parameter set, so with the
# single output time above this is x(1.0) for each row of constants.
states = simple_ode(t_initial, x_initial, solution_times, constants)

# Evaluates the right-hand side dx/dt at (t_initial, x_initial) for every
# parameter set. This is the derivative, not the solution x(1.0).
simple_ode.ode_system(t_initial, x_initial, constants)

我是 TensorFlow 的新手,所以我猜是某处张量的形状不对。我原本希望它能“直接可用”:沿张量的第一个维度迭代,针对每组常量分别求解一次 ODE。感谢任何帮助。

我找到了一个解决办法,虽然不确定它是否最佳:我没有继承 tf.Module,而是继承了 tf.keras.layers.Layer,并用 tf.map_fn 对每组常量分别调用求解器,这样就能正常运行了。(真正起作用的很可能是:每次调用求解器时,初始时间和初始状态都是标量 0.0,而原代码传入的是形状为 (1, 3) 的张量。)代码改动如下:

class ODELayer(tf.keras.layers.Layer):
    """Keras layer that solves one ODE per input row with BDF.

    Each row of the layer input is handed to ``ode_system`` as its
    ``parameters`` constant. The ODE is integrated from ``initial_time`` /
    ``initial_state`` and the states at ``solution_times`` are returned, so
    output row ``i`` corresponds to input row ``i``.
    """

    def __init__(self, num_outputs, ode_system, initial_time=0.0,
                 initial_state=0.0, solution_times=(1.0,), **kwargs):
        """Create the layer.

        Args:
            num_outputs: Stored on the layer; not used by ``call``.
            ode_system: Callable ``(t, x, parameters) -> dx/dt``; receives one
                input row as ``parameters``.
            initial_time: Scalar start time shared by every row.
            initial_state: Initial value of ``x`` for every row.
            solution_times: Increasing times at which to report the solution.
            **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        self.num_outputs = num_outputs
        self.ode_system = ode_system
        self.initial_time = initial_time
        self.initial_state = initial_state
        self.solution_times = solution_times

    def call(self, input_tensor):
        """Solve the ODE once for each row of ``input_tensor``."""
        # map_fn applies solve_ode to each row independently and stacks the
        # per-row results along a new leading axis.
        return tf.map_fn(self.solve_ode, input_tensor)

    def solve_ode(self, parameters):
        """Integrate ``ode_system`` for a single parameter row.

        Returns:
            The solver's ``states``, one entry per value in ``solution_times``.
        """
        # No inner GradientTape: the previous version computed a gradient
        # inside the tape and discarded it. Any enclosing tape (e.g. Keras
        # training) still differentiates through the solve.
        solution = tfp.math.ode.BDF().solve(
            self.ode_system,
            self.initial_time,
            self.initial_state,
            self.solution_times,
            constants={'parameters': parameters},
        )
        return solution.states
    
def simple_ode(t, x, parameters):
    """Return dx/dt = b * exp(a * t) + a * x for a single parameter pair.

    ``parameters[0]`` is ``a`` and ``parameters[1]`` is ``b``. The signature
    matches the ``ode_fn(t, x, parameters=...)`` form that the solver uses
    when called with ``constants={'parameters': ...}``.
    """
    rate, forcing = parameters[0], parameters[1]
    growth = tf.exp(tf.multiply(rate, t))
    forced_term = tf.multiply(forcing, growth)
    linear_term = tf.multiply(rate, x)
    return tf.add(forced_term, linear_term)

感谢所有看过这个问题或尝试过解答的人。