如何在 Julia 中加速多个广播

How to speedup multiple broadcasts in Julia

这个 Julia 函数似乎效率很低(比等效的 Pythran / C++ 代码慢一个数量级,即使在 Julia 预热之后)...

function my_multi_broadcast(a)
    10 * (2*a.^2 + 4*a.^3) + 2 ./ a
end

arr = ones(1000, 1000)
my_multi_broadcast(arr)

我猜只是我写得不正确...如何在 Julia 中实现这样的 "multi broadcasts" 加速?我 guess/hope 我不需要扩展循环...

在第一个答案后编辑

谢谢!通过我的设置,Pythran 解决方案(就地和不就地)仍然快 1.5 到 2 倍(没有 OpenMP)。有没有办法在 Julia 中激活 SIMD 指令?或者另一种加速这种 CPU 计算的方法?

Python代码:

from transonic import jit

@jit
def broadcast(a):
    return 10 * (2*a**2 + 4*a**3) + 2 / a

@jit
def broadcast_inplace(a):
    a[:] = 10 * (2*a**2 + 4*a**3) + 2 / a

根据@simd建议进行编辑

似乎 @simd 不能开箱即用,即只需将其添加到行首即可。

ERROR: LoadError: LoadError: Base.SimdLoop.SimdError("for loop expected")
Stacktrace:
 [1] compile(::Expr, ::Bool) at ./simdloop.jl:54
 [2] @simd(::LineNumberNode, ::Module, ::Any) at ./simdloop.jl:126
 [3] include at ./boot.jl:317 [inlined]
 [4] include_relative(::Module, ::String) at ./loading.jl:1044
 [5] include(::Module, ::String) at ./sysimg.jl:29
 [6] exec_options(::Base.JLOptions) at ./client.jl:231
 [7] _start() at ./client.jl:425

我猜想必须扩展 for 循环,但随后代码 (i) 的可读性大大降低,并且 (ii) 不再独立于维度。

似乎我们有一个例子,简单的 Python/Numpy 代码可以用 Pythran 比我们用 Julia 更快地加速(除非有一种方法可以在 Julia 中加速它?和未来的 Julia版本可能会解决这个问题)。有意思...

像这样广播所有操作:

julia> function my_multi_broadcast2(a)
           @. 10 * (2*a^2 + 4*a^3) + 2 / a
       end
my_multi_broadcast2 (generic function with 1 method)

不同之处在于,在 10 * (2*a.^2 + 4*a.^3) + 2 ./ a 中,您实际上没有利用广播融合,因为 * 并且两个 + 没有广播。

@. 10 * (2*a^2 + 4*a^3) + 2 / a等同于10 .* (2 .* a.^2 .+ 4 .* a.^3) .+ 2 ./ a.

下面是性能对比

julia> @btime my_multi_broadcast($arr);
  58.146 ms (18 allocations: 61.04 MiB)

julia> @btime my_multi_broadcast2($arr);
  5.982 ms (4 allocations: 7.63 MiB)

它与 Pythran / C++ 相比如何,因为我们获得了大约 10 倍的加速?

最后请注意,如果您可以通过以下方式改变 arr

julia> function my_multi_broadcast3(a)
           @. a = 10 * (2*a^2 + 4*a^3) + 2 / a
       end
my_multi_broadcast3 (generic function with 1 method)

julia> @btime my_multi_broadcast3($arr);
  1.840 ms (0 allocations: 0 bytes)

速度更快并且分配为零(我不知道你是想就地修改 arr 还是创建一个新数组,所以我展示了这两种方法)。