(Julia) .+ 运算符似乎没有使用我的自定义广播功能,为什么?

(Julia) .+ operator does not seem to use my custom broadcast function, why?

我正在实现一个自定义矩阵,它只有一个非零值,无论您执行什么操作,这都是矩阵中唯一可以非零的单元格。我称它为 SVMatrix(单值矩阵)。我目前的代码是

struct SVMatrix{T} <: Base.AbstractMatrix{T}
    value::T
    index::Tuple{Int,Int}
    size::Tuple{Int,Int}
end

function Base.broadcast(+, A::SVMatrix, B::AbstractArray)
    SVMatrix(A.value+B[A.index...], A.index, A.size) 
end

function Base.getindex(A::SVMatrix{T}, i::Int) where {T}
    if i == A.index[1] + A.index[2]*A.size[1]
        A.value
    else
        0
    end
end


function Base.getindex(A::SVMatrix{T}, i::Vararg{Int,2}) where {T}
    if i == A.index
        return A.value
    else
        0
    end
end

function Base.size(A::SVMatrix)
    A.size
end

然后我按以下方式将广播功能与 .+ 运算符一起计时

function time(n::Int)
    A = SVMatrix(1.0, (3,4), (n, n))
    B = rand(n,n)
    @time broadcast(+, A, B)
    @time A .+ B
end

time(1000)
println()
time(1000)

得到结果

 0.000000 seconds
 0.008207 seconds (2 allocations: 7.629 MiB, 47.51% gc time)

 0.000000 seconds
 0.008258 seconds (2 allocations: 7.629 MiB)

所以 .+ 似乎没有使用我的自定义广播功能,即使它说 in the documentation

In fact, f.(args...) is equivalent to broadcast(f, args...), providing a convenient syntax to broadcast any function (dot syntax).

为什么我会得到这些结果?

这实际上是您不应该扩展广播的一个很好的例子。

julia> struct SVMatrix{T} <: Base.AbstractMatrix{T}
           value::T
           index::Tuple{Int,Int}
           size::Tuple{Int,Int}
       end

julia> @inline function Base.getindex(A::SVMatrix{T}, i::Vararg{Int,2}) where {T}
           @boundscheck checkbounds(A, i...)
           if i == A.index
               return A.value
           else
               return zero(T)
           end
       end

julia> Base.size(A::SVMatrix) = A.size

julia> SVMatrix(1.0, (1,1), (2, 2)) .+ ones(2, 2)
2×2 Array{Float64,2}:
 2.0  1.0
 1.0  1.0

.+的结果应该[2 0; 0 0]!如果我们使用您的广播实现(更正为在 ::typeof(+) 上分派为 ),当其他人使用它并期望它的行为与所有其他 AbstractArray 一样时,您的数组会出人意料地损坏。

现在,您 可以 return 巧妙地重新计算 SVMatrix 的操作是 .*:

julia> SVMatrix(2.5, (1,1), (2, 2)) .* ones(2, 2)
2×2 Array{Float64,2}:
 2.5  0.0
 0.0  0.0

我们可以在 O(1) space 和时间内完成此操作,但默认实现是遍历所有值并 return 密集 Array。这就是 Julia 的多重分派大放异彩的地方:

julia> Base.broadcasted(::typeof(*), A::SVMatrix, B::AbstractArray) = SVMatrix(A.value*B[A.index...], A.index, A.size)

julia> SVMatrix(2.5, (1,1), (2, 2)) .* ones(2, 2)
2×2 SVMatrix{Float64}:
 2.5  0.0
 0.0  0.0

由于这是一个复杂度为 O(1) 的操作并且是一个 巨大的胜利,我们可以选择退出广播融合并立即重新计算一个新的 SVMatrix — 即使在 "fused" 表达式中。不过,您还没有完成!

  • 需要对兼容的形状进行错误检查。
  • 需要允许广播诸如 SVMatrix(2.5, (1,1), (2, 2)) .* rand(2).
  • 之类的内容
  • 理想情况下,您将实现一个 BroadcastStyle 以允许在 "at least one SVMatrix is in the argument list." 上进行调度 然后您将实现 Base.broadcasted(::ArrayStyle{SVMatrix}, ::typeof(*), args...) 这将允许 SVMatrix 出现在任一侧.*,但这是一个高级主题。