在循环或广播中保留 Array 的元素类型联合以避免动态调度

Preserving Array's element-type union in loops or broadcasts to avoid dynamic dispatch

我想对 Array{Union{A,B,C}, M} 进行元素运算,并且我想要 Array{Union{A,B,C}, N} 的输出,而没有因无法推断 Union{A,B,C} 元素类型而导致的动态调度。

我在下面粘贴了我正在玩的例子。我知道我可以使用类型注释来控制 return 类型,例如local y::eltype(x),但它不会停止元素操作中的 Any 推理(如 @code_warntype 所示)或动态调度(不会影响分配数量,我稍后会报告)。

在示例中,联合拆分似乎最多可处理 3 个具体类型,但循环拆分元素类型时,广播拆分 Vector 类型而不是元素类型参数。我正在寻求通用的方法来编写循环或元素方法来绕过这些推理限制,即使它变得像每种类型的条件语句一样乏味。

begin
  function loop(x)
    local y
    for i in x
      y = i
    end
    return y
  end
        
  function each1(x)
    return x+1
  end
  
  # inputs to test
  const x3 = Union{Int, Float64, Bool}[1, 1.1, true]
  const x4 = Union{Int, Float64, Bool, ComplexF64}[1, 1.1, true, 1.1im]
end

我不会粘贴整个 @code_warntype 打印输出,只粘贴推断的输出类型。此外,我不确定如何衡量动态调度,但它会导致额外的分配,所以我在 @time.

的第二次运行中报告分配情况
typeof(loop(x3)) # Bool
@code_warntype loop(x3) # Union{Bool, Float64, Int64}
@time loop(x3) # 0 allocations

typeof(each1.(x3)) # Vector{Real}
@code_warntype each1.(x3) # Union{Vector{Float64}, Vector{Int64}, Vector{Real}}
@time each1.(x3) # 3 allocations

typeof(loop(x4)) # ComplexF64
@code_warntype loop(x4) # Any
@time loop(x4) # 8 allocations

typeof(each1.(x4)) # Vector{Number}
@code_warntype each1.(x4) # AbstractVector{<:Number}
@time each1.(x4) # 13 allocations

你看到它的双重原因:

  1. 如您所见,如果并集大于 3 种类型,Julia 编译器将放弃尝试进行精确类型推断并使用更粗略的边界。这是一个硬编码假设,以避免过多的代码生成。
  2. 如果可能,广播会缩小集合的 eltype

让我用一个最小的例子来展示第二个问题:

julia> x = Any[1]
1-element Vector{Any}:
 1

julia> identity.(x)
1-element Vector{Int64}:
 1

还有当你写的时候:

broadcasting splits the Vector type instead of the element type parameter

一个准确的信息是广播告诉你结果将是Vector{Float64},或Vector{Int64},或Vector{Real},这取决于数组的实际内容(因为eltype 执行缩小)。

避免它的方法是改用广播赋值(或只是一个循环)。但是,您必须确保结果容器的 eltype 能够存储操作结果。所以在你的情况下它可能是例如:

julia> r3 = similar(x3, Union{Int, Float64})
3-element Vector{Union{Float64, Int64}}:
 0.0
 0.0
 0.0

julia> r3 .= each1.(x3)
3-element Vector{Union{Float64, Int64}}:
 2
 2.1
 2

julia> @time r3 .= each1.(x3)
  0.000034 seconds (1 allocation: 16 bytes)
3-element Vector{Union{Float64, Int64}}:
 2
 2.1
 2

(请注意,我已经手动选择了 r3eltype 以便它是正确的 - 即我知道根据 each1 的定义我可能期望得到什么类型的结果以及 x3)

的类型