Julia 中的批量矩阵乘法

Batch matrix multiplication in Julia

我正在尝试在 Julia 中将 N 维 (N>=3) 数组乘以矩阵批次,即沿最后两个维度执行矩阵乘法,同时保持其他维度不变。

例如,如果 x 的维度为 (d1,d2,4,3),而 y 的维度为 (d1,d2,3,2),则乘法的结果应为 (d1,d2,4,2),即应执行一批矩阵乘法。

这正是 Python 中发生的事情 numpy.matmul:

If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.

np.matmul(randn(10,10,4,3), randn(10,10,3,2)).shape
(10, 10, 4, 2)

有没有办法在 Julia 中重现 numpy.matmul 的行为?

我希望 .* 会起作用,但是:

julia> randn(10,10,4,3) .* randn(10,10,3,2)
ERROR: DimensionMismatch("arrays could not be broadcast to a common size")
Stacktrace:
 [1] _bcs1 at ./broadcast.jl:485 [inlined]
 [2] _bcs at ./broadcast.jl:479 [inlined] (repeats 3 times)
 [3] broadcast_shape at ./broadcast.jl:473 [inlined]
 [4] combine_axes at ./broadcast.jl:468 [inlined]
 [5] instantiate at ./broadcast.jl:256 [inlined]
 [6] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4},Nothing,typeof(*),Tuple{Array{Float64,4},Array{Float64,4}}}) at ./broadcast.jl:798
 [7] top-level scope at REPL[80]:1

我知道列表推导式可能适用于 3-D,但这在更高维度上会变得非常混乱。重塑(或查看)除最后 2 个维度之外的所有维度、使用列表理解并将其重塑回来的最佳解决方案是什么?或者有更好的方法吗?

P.S。我能找到的最接近的是 this,但它并不完全相同。 Julia 的新手,因此可能会遗漏一些对 Julia 用户来说很明显的东西。

我不知道有任何这样的功能,但在某些程序包中很可能有。我认为在 Julia 中,将数据组织为矩阵数组并在其上广播矩阵乘法更为自然:

D = [rand(50, 60) for i in 1:4, j in 1:3]
E = [rand(60, 70) for i in 1:4, j in 1:3]
D .* E  # now you can use dot broadcasting!

话虽如此,制作您自己的产品很容易。不过,我会做出一点改变。 Julia 是主要的列,而 numpy 是 "last dimension major",因此你应该让矩阵位于 第一个 两个维度,而不是最后两个

首先,我将定义一个乘以数组 C 的就地方法,然后定义一个调用就地版本的非就地方法(我将跳过维度检查等):

# In-place version, note the use of the @views macro, 
# which is essential to get in-place behaviour

using LinearAlgebra: mul!  # fast in-place matrix multiply

function batchmul!(C, A, B)
    for j in axes(A, 4), i in axes(A, 3)
        @views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
    end
    return C
end

# The non-in-place version
function batchmul(A, B)
    T = promote_type(eltype(A), eltype(B))
    C = Array{T}(undef, size(A, 1), size(B)[2:end]...)
    return batchmul!(C, A, B)
end

你也可以让它成为多线程的。 在我的计算机上,4 个线程提供了 2.5 倍的加速(实际上,对于最后两个维度的较大值,我获得了 3.5 倍的加速)你获得多少加速取决于大小和所涉及阵列的形状:

function batchmul!(C, A, B)
    Threads.@threads for j in axes(A, 4)
        for i in axes(A, 3)
            @views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
        end
    end
    return C
end

编辑: 我刚才注意到您需要一般 N-D,而不仅仅是 4-D。不过,不应该太难概括。无论如何,更有理由选择矩阵数组,其中广播将自动适用于所有维度。

Edit2: 不能离开它,所以这是 N-D 情况下的一个(还有更多工作要做,比如处理非基于 1 的索引(更新:axes 应该解决这个问题)):

function batchmul!(C, A, B)
    Threads.@threads for I in CartesianIndices(axes(A)[3:end])
        @views mul!(C[:, :, Tuple(I)...], A[:, :, Tuple(I)...], B[:, :, Tuple(I)...])
    end
    return C
end

对于 N=3,您正在寻找 NNlib.batched_mul。请注意(如上所述)Julia 的数组是列优先的,因此将最后一个索引而不是第一个索引视为批处理中的 运行 通常是有意义的:

julia> using NNlib

julia> randn(4,3,100) ⊠ randn(3,2,100)
4×2×100 Array{Float64, 3}:
[:, :, 1] =
  0.9292     -0.223521
...

这只是一个类似batchmul!(C, A, B)的循环,但它也会为GPU CuArray调用适当的库函数。

将其扩展到超过 3 个维度并不难,但必须有人去做,并决定规则。对于第 3 个维度,它的行为类似于广播:

julia> randn(4,3,100) ⊠ randn(3,2) |> size
(4, 2, 100)

julia> randn(4,3,100) ⊠ randn(3,2,1) |> size
(4, 2, 100)

julia> try randn(4,3,100) ⊠ randn(3,2,2) catch e println(e) end
DimensionMismatch("batch size mismatch: A != B")