Julia 中的批量矩阵乘法
Batch matrix multiplication in Julia
我正在尝试在 Julia 中将 N 维 (N>=3) 数组乘以矩阵批次,即沿最后两个维度执行矩阵乘法,同时保持其他维度不变。
例如,如果 x
的维度为 (d1,d2,4,3)
,而 y
的维度为 (d1,d2,3,2)
,则乘法的结果应为 (d1,d2,4,2)
,即应执行一批矩阵乘法。
这正是 Python 中发生的事情 numpy.matmul
:
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
np.matmul(randn(10,10,4,3), randn(10,10,3,2)).shape
(10, 10, 4, 2)
有没有办法在 Julia 中重现 numpy.matmul
的行为?
我希望 .*
会起作用,但是:
julia> randn(10,10,4,3) .* randn(10,10,3,2)
ERROR: DimensionMismatch("arrays could not be broadcast to a common size")
Stacktrace:
[1] _bcs1 at ./broadcast.jl:485 [inlined]
[2] _bcs at ./broadcast.jl:479 [inlined] (repeats 3 times)
[3] broadcast_shape at ./broadcast.jl:473 [inlined]
[4] combine_axes at ./broadcast.jl:468 [inlined]
[5] instantiate at ./broadcast.jl:256 [inlined]
[6] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4},Nothing,typeof(*),Tuple{Array{Float64,4},Array{Float64,4}}}) at ./broadcast.jl:798
[7] top-level scope at REPL[80]:1
我知道列表推导式可能适用于 3-D,但这在更高维度上会变得非常混乱。重塑(或查看)除最后 2 个维度之外的所有维度、使用列表理解并将其重塑回来的最佳解决方案是什么?或者有更好的方法吗?
P.S。我能找到的最接近的是 this,但它并不完全相同。 Julia 的新手,因此可能会遗漏一些对 Julia 用户来说很明显的东西。
我不知道有任何这样的功能,但在某些程序包中很可能有。我认为在 Julia 中,将数据组织为矩阵数组并在其上广播矩阵乘法更为自然:
D = [rand(50, 60) for i in 1:4, j in 1:3]
E = [rand(60, 70) for i in 1:4, j in 1:3]
D .* E # now you can use dot broadcasting!
话虽如此,制作您自己的产品很容易。不过,我会做出一点改变。 Julia 是主要的列,而 numpy 是 "last dimension major",因此你应该让矩阵位于 第一个 两个维度,而不是最后两个
首先,我将定义一个乘以数组 C
的就地方法,然后定义一个调用就地版本的非就地方法(我将跳过维度检查等):
# In-place version, note the use of the @views macro,
# which is essential to get in-place behaviour
using LinearAlgebra: mul! # fast in-place matrix multiply
function batchmul!(C, A, B)
for j in axes(A, 4), i in axes(A, 3)
@views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
end
return C
end
# The non-in-place version
function batchmul(A, B)
T = promote_type(eltype(A), eltype(B))
C = Array{T}(undef, size(A, 1), size(B)[2:end]...)
return batchmul!(C, A, B)
end
你也可以让它成为多线程的。 在我的计算机上,4 个线程提供了 2.5 倍的加速(实际上,对于最后两个维度的较大值,我获得了 3.5 倍的加速)你获得多少加速取决于大小和所涉及阵列的形状:
function batchmul!(C, A, B)
Threads.@threads for j in axes(A, 4)
for i in axes(A, 3)
@views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
end
end
return C
end
编辑: 我刚才注意到您需要一般 N-D,而不仅仅是 4-D。不过,不应该太难概括。无论如何,更有理由选择矩阵数组,其中广播将自动适用于所有维度。
Edit2: 不能离开它,所以这是 N-D 情况下的一个(还有更多工作要做,比如处理非基于 1 的索引(更新:axes
应该解决这个问题)):
function batchmul!(C, A, B)
Threads.@threads for I in CartesianIndices(axes(A)[3:end])
@views mul!(C[:, :, Tuple(I)...], A[:, :, Tuple(I)...], B[:, :, Tuple(I)...])
end
return C
end
对于 N=3,您正在寻找 NNlib.batched_mul
。请注意(如上所述)Julia 的数组是列优先的,因此将最后一个索引而不是第一个索引视为批处理中的 运行 通常是有意义的:
julia> using NNlib
julia> randn(4,3,100) ⊠ randn(3,2,100)
4×2×100 Array{Float64, 3}:
[:, :, 1] =
0.9292 -0.223521
...
这只是一个类似batchmul!(C, A, B)
的循环,但它也会为GPU CuArray
调用适当的库函数。
将其扩展到超过 3 个维度并不难,但必须有人去做,并决定规则。对于第 3 个维度,它的行为类似于广播:
julia> randn(4,3,100) ⊠ randn(3,2) |> size
(4, 2, 100)
julia> randn(4,3,100) ⊠ randn(3,2,1) |> size
(4, 2, 100)
julia> try randn(4,3,100) ⊠ randn(3,2,2) catch e println(e) end
DimensionMismatch("batch size mismatch: A != B")
我正在尝试在 Julia 中将 N 维 (N>=3) 数组乘以矩阵批次,即沿最后两个维度执行矩阵乘法,同时保持其他维度不变。
例如,如果 x
的维度为 (d1,d2,4,3)
,而 y
的维度为 (d1,d2,3,2)
,则乘法的结果应为 (d1,d2,4,2)
,即应执行一批矩阵乘法。
这正是 Python 中发生的事情 numpy.matmul
:
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
np.matmul(randn(10,10,4,3), randn(10,10,3,2)).shape
(10, 10, 4, 2)
有没有办法在 Julia 中重现 numpy.matmul
的行为?
我希望 .*
会起作用,但是:
julia> randn(10,10,4,3) .* randn(10,10,3,2)
ERROR: DimensionMismatch("arrays could not be broadcast to a common size")
Stacktrace:
[1] _bcs1 at ./broadcast.jl:485 [inlined]
[2] _bcs at ./broadcast.jl:479 [inlined] (repeats 3 times)
[3] broadcast_shape at ./broadcast.jl:473 [inlined]
[4] combine_axes at ./broadcast.jl:468 [inlined]
[5] instantiate at ./broadcast.jl:256 [inlined]
[6] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{4},Nothing,typeof(*),Tuple{Array{Float64,4},Array{Float64,4}}}) at ./broadcast.jl:798
[7] top-level scope at REPL[80]:1
我知道列表推导式可能适用于 3-D,但这在更高维度上会变得非常混乱。重塑(或查看)除最后 2 个维度之外的所有维度、使用列表理解并将其重塑回来的最佳解决方案是什么?或者有更好的方法吗?
P.S。我能找到的最接近的是 this,但它并不完全相同。 Julia 的新手,因此可能会遗漏一些对 Julia 用户来说很明显的东西。
我不知道有任何这样的功能,但在某些程序包中很可能有。我认为在 Julia 中,将数据组织为矩阵数组并在其上广播矩阵乘法更为自然:
D = [rand(50, 60) for i in 1:4, j in 1:3]
E = [rand(60, 70) for i in 1:4, j in 1:3]
D .* E # now you can use dot broadcasting!
话虽如此,制作您自己的产品很容易。不过,我会做出一点改变。 Julia 是主要的列,而 numpy 是 "last dimension major",因此你应该让矩阵位于 第一个 两个维度,而不是最后两个
首先,我将定义一个乘以数组 C
的就地方法,然后定义一个调用就地版本的非就地方法(我将跳过维度检查等):
# In-place version, note the use of the @views macro,
# which is essential to get in-place behaviour
using LinearAlgebra: mul! # fast in-place matrix multiply
function batchmul!(C, A, B)
for j in axes(A, 4), i in axes(A, 3)
@views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
end
return C
end
# The non-in-place version
function batchmul(A, B)
T = promote_type(eltype(A), eltype(B))
C = Array{T}(undef, size(A, 1), size(B)[2:end]...)
return batchmul!(C, A, B)
end
你也可以让它成为多线程的。 在我的计算机上,4 个线程提供了 2.5 倍的加速(实际上,对于最后两个维度的较大值,我获得了 3.5 倍的加速)你获得多少加速取决于大小和所涉及阵列的形状:
function batchmul!(C, A, B)
Threads.@threads for j in axes(A, 4)
for i in axes(A, 3)
@views mul!(C[:, :, i, j], A[:, :, i, j], B[:, :, i, j])
end
end
return C
end
编辑: 我刚才注意到您需要一般 N-D,而不仅仅是 4-D。不过,不应该太难概括。无论如何,更有理由选择矩阵数组,其中广播将自动适用于所有维度。
Edit2: 不能离开它,所以这是 N-D 情况下的一个(还有更多工作要做,比如处理非基于 1 的索引(更新:axes
应该解决这个问题)):
function batchmul!(C, A, B)
Threads.@threads for I in CartesianIndices(axes(A)[3:end])
@views mul!(C[:, :, Tuple(I)...], A[:, :, Tuple(I)...], B[:, :, Tuple(I)...])
end
return C
end
对于 N=3,您正在寻找 NNlib.batched_mul
。请注意(如上所述)Julia 的数组是列优先的,因此将最后一个索引而不是第一个索引视为批处理中的 运行 通常是有意义的:
julia> using NNlib
julia> randn(4,3,100) ⊠ randn(3,2,100)
4×2×100 Array{Float64, 3}:
[:, :, 1] =
0.9292 -0.223521
...
这只是一个类似batchmul!(C, A, B)
的循环,但它也会为GPU CuArray
调用适当的库函数。
将其扩展到超过 3 个维度并不难,但必须有人去做,并决定规则。对于第 3 个维度,它的行为类似于广播:
julia> randn(4,3,100) ⊠ randn(3,2) |> size
(4, 2, 100)
julia> randn(4,3,100) ⊠ randn(3,2,1) |> size
(4, 2, 100)
julia> try randn(4,3,100) ⊠ randn(3,2,2) catch e println(e) end
DimensionMismatch("batch size mismatch: A != B")