Is there a unified syntax for element-wise in-place operations on scalars and arrays in Julia?

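Consider the following accumulator type. It behaves like an array in that you can push values onto it, but it only keeps track of their mean: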

mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator, term)
    acc.data += term       # <-- meant to be an in-place addition
    acc.count += 1
    acc
end

mean(acc::Accumulator) = acc.data ./ acc.count

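I want this to work with T being either a scalar or an array type. However, it turns out that when T is an array type, the addition in push! allocates a temporary array. This is because in Julia, x += a is equivalent to x = x + a, and I suspect Julia cannot guarantee that acc.data and term do not alias.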
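A minimal sketch of the distinction described above. The function names `rebind_add` and `inplace_add` are hypothetical and exist only for this illustration. `x += a` means `x = x + a`, so `x + a` builds a new array and `x` is rebound to it. By contrast, `x .+= a` means `x .= x .+ a`, which writes the result into the existing array. The check compares object identity (`===`) instead of counting allocations, because allocation counts also pick up compilation effects.

```julia
# `x += a` means `x = x + a`: `x + a` allocates a new array and `x` is rebound to it.
function rebind_add(x, a)
    original = x
    x += a
    return x === original   # false for arrays: `x` now refers to a new object
end

# `x .+= a` means `x .= x .+ a`: the result is written into the existing array.
function inplace_add(x, a)
    original = x
    x .+= a
    return x === original   # true: same array, mutated in place
end

v = [1, 2]
rebind_add(v, [3, 4])    # returns false; `v` itself is unchanged, still [1, 2]
inplace_add(v, [3, 4])   # returns true; `v` is now [4, 6]
```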

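A simple fix would be to replace += with the element-wise (broadcasting) update .+=. However, that breaks scalar types, for which it is not allowed. So the only solution I could come up with is to add a specialization of the following form: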

function Base.push!(acc::Accumulator, term::AbstractArray)
    acc.data .+= term       # <-- element-wise addition
    acc.count += 1
    acc
end

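However, this specialization is somewhat ugly and also brittle. Does anyone know a better approach, ideally one that is generic and creates no temporaries?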

Curiously, Numbers are iterable in Julia, but that does not seem to help here, because Numbers have no setindex! method.
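A quick check of this observation. The helper `throws` is a hypothetical name made up for this sketch. A `Number` iterates like a one-element collection, but it has no storage for indexed assignment to write into. The same is true of a dotted update such as `.+=`, which is why `.+=` cannot simply replace `+=` for scalar accumulators. The two failing operations are wrapped in `try`/`catch` only so that the script runs to completion.

```julia
# A number iterates exactly once, yielding itself.
seen = Int[]
for v in 5
    push!(seen, v)
end
@show seen               # seen = [5]
@show length(5)          # length(5) = 1

# Hypothetical helper: report whether calling `f()` throws an exception.
function throws(f)
    try
        f()
        return false
    catch
        return true
    end
end

# There is no `setindex!` method for numbers, so indexed assignment fails ...
index_fails = throws(() -> (x = 5; x[1] = 6))
# ... and so does a dotted update, which needs a destination it can write into.
broadcast_fails = throws(() -> (x = 5; x .+= 1))
@show index_fails broadcast_fails   # both true
```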

Here are two different approaches. The first uses iterator traits. The second just tweaks the method signatures slightly to handle the corner cases.

Iterator traits

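We can use the IteratorSize trait to distinguish scalars from arrays. For scalars, Base.IteratorSize(x) returns Base.HasShape{0}(). For arrays, Base.IteratorSize(x) returns Base.HasShape{N}(), where N is the number of dimensions of the array. Base.IteratorSize also accepts a type instead of a value, and that is the form push! uses below, calling it on the type parameters T and S.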
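These trait values can be inspected directly. The sketch below calls `Base.IteratorSize` on a few values and on the corresponding types. Each call returns an *instance* such as `Base.HasShape{1}()`, and that is what the `_push_acc!` methods dispatch on.

```julia
# For scalars, IteratorSize reports a zero-dimensional shape ...
@show Base.IteratorSize(42)               # Base.HasShape{0}()
@show Base.IteratorSize(3.5)              # Base.HasShape{0}()

# ... and for arrays, a shape whose parameter is the number of dimensions.
@show Base.IteratorSize([1, 2])           # Base.HasShape{1}()
@show Base.IteratorSize(zeros(2, 3))      # Base.HasShape{2}()

# The same query also works on types, the form push! uses with T and S.
@show Base.IteratorSize(Int)              # Base.HasShape{0}()
@show Base.IteratorSize(Matrix{Float64})  # Base.HasShape{2}()
```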

mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator{T}, term::S) where {T, S}
    _push_acc!(Base.IteratorSize(T), Base.IteratorSize(S), acc, term)
end

function _push_acc!(::Base.HasShape{0}, ::Base.HasShape{0}, acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

function _push_acc!(::Base.HasShape{N}, ::Base.HasShape{N}, acc::Accumulator, term) where {N}
    acc.data .+= term
    acc.count += 1
    acc
end

function _push_acc!(::Base.HasShape{M}, ::Base.HasShape{N}, ::Accumulator, ::Any) where {M, N}
    throw(ArgumentError("Accumulator and term have inconsistent shapes"))
end

In action at the REPL:

julia> a = Accumulator(1, 0)
Accumulator{Int64}(1, 0)

julia> b = Accumulator([1, 2], 0)
Accumulator{Array{Int64,1}}([1, 2], 0)

julia> push!(a, 42)
Accumulator{Int64}(43, 1)

julia> push!(b, [3, 4])
Accumulator{Array{Int64,1}}([4, 6], 1)

julia> push!(a, [5, 6])
ERROR: ArgumentError: Accumulator and term have inconsistent shapes
Stacktrace:
 [1] _push_acc!(::Base.HasShape{0}, ::Base.HasShape{1}, ::Accumulator{Int64}, ::Array{Int64,1}) at ...
 [2] push!(::Accumulator{Int64}, ::Array{Int64,1}) at ...
 [3] top-level scope at REPL[6]:1

julia> push!(b, 10)
ERROR: ArgumentError: Accumulator and term have inconsistent shapes
Stacktrace:
 [1] _push_acc!(::Base.HasShape{1}, ::Base.HasShape{0}, ::Accumulator{Array{Int64,1}}, ::Int64) at ...
 [2] push!(::Accumulator{Array{Int64,1}}, ::Int64) at ...
 [3] top-level scope at REPL[7]:1
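The REPL session shows the dispatch, but it does not show that the array case avoids the temporary the question was about. One way to check this is to keep a reference to the stored array and confirm that it is still the same object (`===`) after pushing. The type and methods below are the ones from this answer, repeated so the snippet runs on its own. The `mean` definition is the one from the question. The variable name `storage` is only illustrative.

```julia
mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator{T}, term::S) where {T, S}
    _push_acc!(Base.IteratorSize(T), Base.IteratorSize(S), acc, term)
end

# Both zero-dimensional: plain addition (rebinding is fine for immutable numbers).
function _push_acc!(::Base.HasShape{0}, ::Base.HasShape{0}, acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

# Same number of dimensions: broadcast the sum into the existing array.
function _push_acc!(::Base.HasShape{N}, ::Base.HasShape{N}, acc::Accumulator, term) where {N}
    acc.data .+= term
    acc.count += 1
    acc
end

# Any other combination of shapes is rejected.
function _push_acc!(::Base.HasShape{M}, ::Base.HasShape{N}, ::Accumulator, ::Any) where {M, N}
    throw(ArgumentError("Accumulator and term have inconsistent shapes"))
end

mean(acc::Accumulator) = acc.data ./ acc.count

b = Accumulator([0.0, 0.0], 0)
storage = b.data                 # keep a handle on the original array
push!(b, [1.0, 2.0])
push!(b, [3.0, 4.0])
@show b.data === storage         # true: updated in place, never replaced
@show mean(b)                    # [2.0, 3.0]
```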

Patching the method signatures

Instead of using iterator traits, we can make a small tweak to your push! method signatures to prevent pushing an array onto a scalar:

mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

function Base.push!(acc::Accumulator{T}, term::AbstractArray) where {T <: AbstractArray}
    acc.data .+= term
    acc.count += 1
    acc
end

function Base.push!(::Accumulator, ::AbstractArray)
    throw(ArgumentError("Can't push an array onto a scalar"))
end

Now, if we try to push an array onto a scalar, we get a sensible error message:

julia> a = Accumulator(42, 0)
Accumulator{Int64}(42, 0)

julia> push!(a, [1, 2])
ERROR: ArgumentError: Can't push an array onto a scalar
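This patch differs from the trait-based version in one respect: it only guards against pushing an array onto a scalar. Pushing a scalar onto an array accumulator still reaches the generic method. There, `acc.data += term` evaluates `Vector + Int`, which Julia 1.x does not define, so the caller sees a `MethodError` rather than an `ArgumentError`. The sketch below checks this, together with the in-place update in the array case. The methods are repeated from this answer so that the snippet runs standalone. The names `storage` and `err` are only illustrative.

```julia
mutable struct Accumulator{T}
    data::T
    count::Int64
end

# Generic fallback: used for scalar accumulators (and anything not matched below).
function Base.push!(acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

# Array onto array: update the stored array in place.
function Base.push!(acc::Accumulator{T}, term::AbstractArray) where {T <: AbstractArray}
    acc.data .+= term
    acc.count += 1
    acc
end

# Array onto anything else: reject with a clear message.
function Base.push!(::Accumulator, ::AbstractArray)
    throw(ArgumentError("Can't push an array onto a scalar"))
end

b = Accumulator([1, 2], 0)
storage = b.data
push!(b, [3, 4])
@show b.data === storage   # true: the array case mutates in place

# Scalar onto array falls through to the generic method, and `[4, 6] + 10` has no method.
err = try
    push!(b, 10)
    nothing
catch e
    e
end
@show typeof(err)          # MethodError
```

If the friendlier error matters for that direction too, the trait-based version above already covers both mismatches.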