使文字常量的类型依赖于其他变量

Make type of literal constant depend on other variables

我在 Julia 中有以下代码,其中文字常量 2. 对数组元素进行乘法运算。我现在将文字常量设置为单精度 (2.f0),但我想让类型取决于其他变量(这些变量要么都是 Float64,要么都是 Float32)。我该如何以优雅的方式做到这一点?

function diff!(
        at, a,
        visc, dxidxi, dyidyi, dzidzi,
        itot, jtot, ktot)
​
    @tturbo for k in 2:ktot-1
        for j in 2:jtot-1
            for i in 2:itot-1
                at[i, j, k] += visc * (
                    (a[i-1, j  , k  ] - 2.f0 * a[i, j, k] + a[i+1, j  , k  ]) * dxidxi +
                    (a[i  , j-1, k  ] - 2.f0 * a[i, j, k] + a[i  , j+1, k  ]) * dyidyi +
                    (a[i  , j  , k-1] - 2.f0 * a[i, j, k] + a[i  , j  , k+1]) * dzidzi )
            end
        end
    end
end

一般来说,如果你有一个标量x或者一个数组A,你可以分别用T = typeof(x)或者T = eltype(A)获取类型,然后使用将文字转换为等效类型,例如

julia> A = [1.0]
1-element Vector{Float64}:
 1.0

julia> T = eltype(A)
Float64

julia> T(2)
2.0

所以你原则上可以在函数中使用它,如果一切都是类型稳定的,这个应该实际上是无开销的:

julia> @code_native 2 * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ promotion.jl:322 within `*'
; │┌ @ promotion.jl:292 within `promote'
; ││┌ @ promotion.jl:269 within `_promote'
; │││┌ @ number.jl:7 within `convert'
; ││││┌ @ float.jl:94 within `Float32'
    vcvtsi2ss   %rdi, %xmm1, %xmm1
; │└└└└
; │ @ promotion.jl:322 within `*' @ float.jl:331
    vmulss  %xmm0, %xmm1, %xmm0
; │ @ promotion.jl:322 within `*'
    retq
    nopw    (%rax,%rax)
; └

julia> @code_native 2.0f0 * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ float.jl:331 within `*'
    vmulss  %xmm1, %xmm0, %xmm0
    retq
    nopw    %cs:(%rax,%rax)
; └

julia> @code_native Float32(2) * 1.0f0
    .section    __TEXT,__text,regular,pure_instructions
; ┌ @ float.jl:331 within `*'
    vmulss  %xmm1, %xmm0, %xmm0
    retq
    nopw    %cs:(%rax,%rax)
; └

然而,碰巧的是,在 Julia 中有一个更优雅的模式来编写函数签名,这样它将参数化地专注于您传递给此函数的数组的元素类型,然后您应该能够在没有开销的情况下使用以确保您的文字是适当的类型,如下所示:

function diff!(at::AbstractArray{T}, a::AbstractArray{T},
        visc, dxidxi, dyidyi, dzidzi,
        itot, jtot, ktot) where T <: Number

    @tturbo for k in 2:ktot-1
        for j in 2:jtot-1
            for i in 2:itot-1
                at[i, j, k] += visc * (
                    (a[i-1, j  , k  ] - T(2) * a[i, j, k] + a[i+1, j  , k  ]) * dxidxi +
                    (a[i  , j-1, k  ] - T(2) * a[i, j, k] + a[i  , j+1, k  ]) * dyidyi +
                    (a[i  , j  , k-1] - T(2) * a[i, j, k] + a[i  , j  , k+1]) * dzidzi )
            end
        end
    end
end

Julia

中关于 parametric methods 的文档在某种程度上讨论了这种方法

Base中有一个不错的小功能:

help?> oftype
search: oftype

  oftype(x, y)

  Convert y to the type of x (convert(typeof(x), y)).

  Examples
  ≡≡≡≡≡≡≡≡≡≡

  julia> x = 4;
  
  julia> y = 3.;
  
  julia> oftype(x, y)
  3
  
  julia> oftype(y, x)
  4.0

所以你可以使用像

这样的东西
two = oftype(at[i,j,k], 2)

在适当的地方。

对于一次多个变量,你可以这样写

two, visc, dxidxi, dyidyi, dzidzi = convert.(T, 2, visc, dxidxi, dyidyi, dzidzi)

在顶部(使用 T @cbk 的回答中的类型参数),因为 oftype(x, y) = convert(typeof(x), y).