用于求解同步 ODE 的并行化代码 (DifferentialEquations.jl) - Julia

Parallelizing code for solving simultaneous ODEs (DifferentialEquations.jl) - Julia

我有以下 ODE 耦合系统(来自积分微分 PDE 的离散化):

xi 是我控制的 x 网格上的点。我可以用下面的简单代码解决这个问题:

using DifferentialEquations

function ode_syst(du,u,p, t)
    N = Int64(p[1])
    beta= p[2]
    deltax = 1/(N+1)
    xs = [deltax*i for i in 1:N]
    for j in 1:N
        du[j] = -xs[j]^(beta)*u[j]+deltax*sum([u[i]*xs[i]^(beta) for i in 1:N])
    end
end

N = 1000
u0 = ones(N)
beta = 2.0
p = [N, beta]
tspan = (0.0, 10^3);

prob = ODEProblem(ode_syst,u0,tspan,p);
sol = solve(prob);

然而,当我使我的网格更细时,即增加 N,计算时间迅速增长(我猜缩放比例是 N 的二次方)。对于如何使用分布式并行或多线程来实现它有什么建议吗?

附加信息:我附上了分析图,它可能有助于了解程序大部分时间花在哪里

分析公式。显然,原子术语重复了。所以它们应该只计算一次。

function ode_syst(du,u,p, t)
    N = Int64(p[1])
    beta= p[2]
    deltax = 1/(N+1)
    xs = [deltax*i for i in 1:N]
    term = [ xs[i]^(beta)*u[i] for i in 1:N]
    term_sum = deltax*sum(term)
    for j in 1:N
        du[j] = -term[j]+term_sum
    end
end

这应该只会在 N 中线性增加。

我查看了您的代码并发现了一些问题,例如由于重新计算总和项而意外引入了 O(N^2) 行为。

我的改进版本使用 Tullio 包来进一步加快矢量化速度。 Tullio 还具有可调参数,如果您的系统变得足够大,这些参数将允许多线程。在这里查看您可以调整哪些参数 in the options section。您可能还会在那里看到 GPU 支持,我没有对此进行测试,但它可能会进一步加速或崩溃。我还选择从 acutal 数组中获取长度,这应该使使用更经济且更不容易出错。

using Tullio

function ode_syst_t(du,u,p, t)
    N = length(du)
    beta = p[1]
    deltax = 1/(N+1)
    @tullio s := deltax*(u[i]*(i*deltax)^(beta))
    @tullio du[j] = -(j*deltax)^(beta)*u[j] + s
    return nothing
end

您的代码:

 @btime sol = solve(prob);
  80.592 s (1349001 allocations: 10.22 GiB)

我的代码:

prob2 = ODEProblem(ode_syst_t,u0,tspan,[2.0]);
@btime sol2 = solve(prob2);
  1.171 s (696 allocations: 18.50 MiB)

结果基本一致:

julia> sum(abs2, sol2(1000.0) .- sol(1000.0))
1.079046922815598e-14

我还对 Lutz Lehmanns 解决方案进行了基准测试:

prob3 = ODEProblem(ode_syst_lehm,u0,tspan,p);

@btime sol3 = solve(prob3);
  1.338 s (3348 allocations: 39.38 MiB)

然而,当我们将 N 缩放到 1000000 时,tspan 为 (0.0, 10.0)

prob2 = ODEProblem(ode_syst_t,u0,tspan,[2.0]);

@time solve(prob2);
  2.429239 seconds (280 allocations: 869.768 MiB, 13.60% gc time)

prob3 = ODEProblem(ode_syst_lehm,u0,tspan,p);

@time solve(prob3);
  5.961889 seconds (580 allocations: 1.967 GiB, 11.08% gc time)

由于在我生锈的旧机器中使用了 2 个内核,我的代码速度提高了两倍多。