julia:避免许多函数参数的高效而整洁的方法

julia: efficient and tidy way to avoid many function arguments

我一直在用 Julia 编写随机 PDE 模拟,随着我的问题变得越来越复杂,独立参数的数量也越来越多。那么什么开头,

myfun(N,M,dt,dx,a,b)

最终变成

myfun(N,M,dt,dx,a,b,c,d,e,f,g,h)

这会导致 (1) 代码混乱,(2) 由于函数参数放错地方而导致错误的可能性增加,(3) 无法泛化以用于其他函数。

(3) 很重要,因为我对我的代码进行了简单的并行化以评估 PDE 的许多不同运行。所以我想将我的函数转换成一种形式:

myfun(args)

其中 args 包含所有相关参数。我在 Julia 中发现的问题是,创建一个包含我所有相关参数作为属性的 struct 会大大减慢速度。我认为这是由于对结构属性的持续访问。作为一个简单的 (ODE) 工作示例,

function example_fun(N,dt,a,b)
  V = zeros(N+1)
  U = 0
  z = randn(N+1)
  for i=2:N+1
    V[i] = V[i-1]*(1-dt)+U*dt
    U = U*(1-dt/a)+b*sqrt(2*dt/a)*z[i]
  end
  return V
end

如果我尝试将其重写为,

function example_fun2(args)
  V = zeros(args.N+1)
  U = 0
  z = randn(args.N+1)
  for i=2:args.N+1
    V[i] = V[i-1]*(1-args.dt)+U*args.dt
    U = U*(1-args.dt/args.a)+args.b*sqrt(2*args.dt/args.a)*z[i]
  end
  return V
end

然后,虽然函数调用看起来很优雅,但从 class 中重新访问每个属性是很麻烦的,而且这种持续访问属性会减慢模拟速度。什么是更好的解决方案?有没有一种方法可以简单地 'unpack' 结构的属性,这样就不必不断访问它们?如果是这样,这将如何推广?

编辑: 我定义我使用的结构如下:

struct Args
  N::Int64
  dt::Float64
  a::Float64
  b::Float64
end

edit2:我已经意识到,如果您没有在结构定义中指定数组的维度,具有 Array{} 属性的结构会导致性能差异。比如c是参数的一维数组,

struct Args_1
  N::Int64
  c::Array{Float64}
end

在 f(args) 中的性能将比 f(N,c) 差得多。但是,如果我们在struct定义中指定c为一维数组,

struct Args_1
  N::Int64
  c::Array{Float64,1}
end

然后性能损失消失。这个问题和我的函数定义中显示的类型不稳定性似乎解释了我在使用结构作为函数参数时遇到的性能差异。

可能你没有在args的类型声明中声明参数的类型?

考虑这个小例子:

struct argstype
    N
    dt
end
myfun(args) = args.N * args.dt

myfun 类型不稳定 无法推断 return 类型的类型:

@code_warntype myfun(argstype(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype

Body:
  begin 
      return ((Core.getfield)(args::argstype, :N)::Any * (Core.getfield)(args::argstype, :dt)::Any)::Any
  end::Any

但是,如果您声明了类型,那么代码就会变成类型稳定的:

struct argstype2
    N::Int
    dt::Float64
end

@code_warntype myfun(argstype2(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype2

Body:
  begin 
      return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype2, :N)::Int64)::Float64, (Core.getfield)(args::argstype2, :dt)::Float64)::Float64
  end::Float64

您看到推断出 return 类型的 Float64。 使用参数类型 (https://docs.julialang.org/en/v0.6.3/manual/types/#Parametric-Types-1),您的代码仍然同时保持通用和类型稳定:

struct argstype3{T1,T2}
    N::T1
    dt::T2
end

@code_warntype myfun(argstype3(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype3{Int64,Float64}

Body:
  begin 
      return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :N)::Int64)::Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :dt)::Float64)::Float64
  end::Float64

在您的代码中存在类型不稳定性,与初始化为 0(整数)的 U 相关,但如果将其替换为 0(浮点数),则类型不稳定性消失。

对于原始版本("U=0"),函数 example_fun 需要 801.933 ns(对于参数 10,0.1,2.,3.)和 example_fun2 925.323 ns (对于相似的值)。

在类型稳定版本 (U=0.) 中,两者都需要 273 ns (+/5 ns)。因此,这是一个实质性的加速,并且不再有在类型参数中组合参数的惩罚。

完整函数如下:

function example_fun2(args)
     V = zeros(args.N+1)
     U = 0.
     z = randn(args.N+1)
     for i=2:args.N+1
       V[i] = V[i-1]*(1-args.dt)+U*args.dt
       U = U*(1-args.dt/args.a)+args.b*sqrt(2*args.dt/args.a)*z[i]
     end
     return V
end