在 System.Numerics.Vector<T> 中使用 F# 度量单位

Using F# Units of Measure with System.Numerics.Vector<T>

我很难将 F# 度量单位与 System.Numerics.Vector<'T> 类型结合使用。让我们来看一个玩具问题:假设我们有一个类型为 float<m>[] 的数组 xs,并且出于某种原因我们想对其所有分量求平方,从而得到一个类型为 float<m^2>[] 的数组。这与标量代码完美配合:

xs |> Array.map (fun x -> x * x) // float<m^2>[]

现在假设我们想通过使用 SIMD 在 System.Numerics.Vector<float>.Count 大小的块中执行乘法来向量化此操作,例如像这样:

open System.Numerics
let simdWidth = Vector<float>.Count
// fill with dummy data
let xs = Array.init (simdWidth * 10) (fun i -> float i * 1.0<m>)
// array to store the results
let rs: float<m^2> array = Array.zeroCreate (xs |> Array.length)
// number of SIMD operations required
let chunks = (xs |> Array.length) / simdWidth
// for simplicity, assume xs.Length % simdWidth = 0
for i = 0 to chunks - 1 do
    let v = Vector<_>(xs, i * simdWidth) // Vector<float<m>>, containing xs.[i .. i+simdWidth-1]
    let u = v * v                        // Vector<float<m>>; expected: Vector<float<m^2>>
    u.CopyTo(rs, i * simdWidth)          // units mismatch

我相信我理解 为什么 会发生这种情况:F# 编译器如何知道 System.Numerics.Vector<'T>.op_Multiply 做什么以及应用什么算术规则?它实际上可以是任何操作。那么它应该如何推断出正确的单位呢?

问题是:完成这项工作的最佳方法是什么?我们如何告诉编译器适用哪些规则?

尝试 1:从 xs 中删除所有度量单位信息,稍后再添加回来:

// remove all UoM from all arrays
let xsWoM = Array.map (fun x -> x / 1.0<m>) xs
// ...
// perform computation using xsWoM etc.
// ...
// add back units again
let xs = Array.map (fun x -> x * 1.0<m>) xsWoM

问题:执行不必要的计算 and/or 复制操作,出于性能原因而无法达到矢量化代码的目的。此外,在很大程度上违背了使用 UoM 的初衷。

尝试2:使用内联IL改变Vector<'T>.op_Multiply的return类型:

// reinterpret x to be of type 'b
let inline retype (x: 'a) : 'b = (# "" x: 'b #)
let inline (.*.) (u: Vector<float<'m>>) (v: Vector<float<'m>>): Vector<float<'m^2>> = u * v |> retype
// ...
let u = v .*. v // asserts type Vector<float<m^2>>

问题:不需要任何额外的操作,但使用了一个已弃用的功能(内联 IL)并且不是完全通用的(仅在度量单位方面)。

有没有人对此有更好的解决方案*

*请注意,上面的示例确实是一个用来演示一般问题的玩具问题。真实程序解决了一个复杂得多的初值问题,涉及到多种物理量

编译器可以很好地弄清楚如何应用乘法的单位规则,这里的问题是你有一个包装类型。在您的第一个示例中,当您编写 xs |> Array.map (fun x -> x * x) 时,您是根据数组的元素而不是直接在数组上描述乘法。

当你有一个 Vector<float<m>> 时,单位附加到 float 而不是 Vector 所以当你尝试乘法 Vector 时,编译器赢了' 将该类型视为具有任何单位。

鉴于 class 公开的方法,我认为直接使用 Vector<'T> 没有简单的解决方法,但有包装类型的选项。

这样的东西可以给你一个 unit-friendly 向量:

type VectorWithUnits<'a, [<Measure>]'b> = 
    |VectorWithUnits of Vector<'a>

    static member inline (*) (a : VectorWithUnits<'a0, 'b0>, b : VectorWithUnits<'a0, 'b1>) 
        : VectorWithUnits<'a0, 'b0*'b1> =
        match a, b with
        |VectorWithUnits a, VectorWithUnits b -> VectorWithUnits <| a * b

在这种情况下,单位附加到向量并且乘法向量按预期工作并具有正确的单位行为。

问题是现在我们可以在 Vector<'T>float 本身上使用单独且不同的度量单位注释。

您可以使用以下方法将具有度量单位的特定类型的数组转换为一组 Vectors:

let toFloatVectors (array : float<'m>[]) : VectorWithUnits<float,'m>[]  =
    let arrs = array |> Array.chunkBySize (Vector<float>.Count)
    arrs |> Array.map (Array.map (float) >> Vector >> VectorWithUnits)

返回:

let fromFloatVectors (vectors : VectorWithUnits<float,'m>[]) : float<'m>[] =
    let arr = Array.zeroCreate<float> (Array.length vectors)
    vectors |> Array.iteri (fun i uVec ->
        match uVec with
        |VectorWithUnits vec -> vec.CopyTo arr)
    arr |> Array.map (LanguagePrimitives.FloatWithMeasure<'m>)

一个骇人听闻的选择:

如果您放弃通用类型 'T,您可以通过一些相当糟糕的装箱和运行时转换使 float Vector 正常运行。这滥用了一个事实,即度量单位是一个编译时构造,在运行时不再存在。

type FloatVectorWithUnits<[<Measure>]'b> = 
    |FloatVectorWithUnits of Vector<float<'b>>

    static member ( * ) (a : FloatVectorWithUnits<'b0>, b : FloatVectorWithUnits<'b1>) =
        match a, b with
        |FloatVectorWithUnits a, FloatVectorWithUnits b ->
            let c, d = box a :?> Vector<float<'b0*'b1>>, box b :?> Vector<float<'b0*'b1>>
            c * d |> FloatVectorWithUnits

我想出了一个解决方案,它满足了我的大部分要求(看起来)。它的灵感来自 (包装 Vector<'T>),但也为底层数组数据类型添加了一个包装器(称为 ScalarField)。这样我们可以跟踪单位,而在下面我们只处理原始数据并且可以使用 not unit-aware System.Numerics.Vector API。

一个简化的、bare-bones 快速而肮脏的实现看起来像这样:

// units-aware wrapper for System.Numerics.Vector<'T>
type PackedScalars<[<Measure>] 'm> = struct
    val public Data: Vector<float>
    new (d: Vector<float>) = {Data = d}
    static member inline (*) (u: PackedScalars<'m1>, v: PackedScalars<'m2>) = u.Data * v.Data |> PackedScalars<'m1*'m2>
end

// unit-ware type, wrapping a raw array for easy stream processing
type ScalarField<[<Measure>] 'm> = struct
    val public Data: float[]
    member self.Item with inline get i                = LanguagePrimitives.FloatWithMeasure<'m> self.Data.[i]
                     and  inline set i (v: float<'m>) = self.Data.[i] <- (float v)
    member self.Packed 
           with inline get i                        = Vector<float>(self.Data, i) |> PackedScalars<'m>
           and  inline set i (v: PackedScalars<'m>) = v.Data.CopyTo(self.Data, i)
    new (d: float[]) = {Data = d}
    new (count: int) = {Data = Array.zeroCreate count}
end

我们现在可以使用两种数据结构以相对优雅、高效的方式解决示例问题:

let xs = Array.init (simdWidth * 10) float |> ScalarField<m>    
let mutable rs = Array.zeroCreate (xs.Data |> Array.length) |> ScalarField<m^2>
let chunks = (xs.Data |> Array.length) / simdWidth
for i = 0 to chunks - 1 do
    let j = i * simdWidth
    let v = xs.Packed(j) // PackedScalars<m>
    let u = v * v        // PackedScalars<m^2>
    rs.Packed(j) <- u

最重要的是,re-implement unit-aware ScalarField 包装器的常规数组操作可能很有用,例如

[<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
module ScalarField =
    let map f (sf: ScalarField<_>) =
        let mutable res = Array.zeroCreate sf.Data.Length |> ScalarField
        for i = 0 to sf.Data.Length do
           res.[i] <- f sf.[i]
        res

等等

缺点:对于基础数字类型 (float) 而言不是通用的,因为 floatWithMeasure 没有通用替代品。为了使其通用,我们必须实现第三个包装器 Scalar,它还包装了底层原语:

type Scalar<'a, [<Measure>] 'm> = struct
    val public Data: 'a
    new (d: 'a) = {Data = d}
end

type PackedScalars<'a, [<Measure>] 'm 
            when 'a: (new: unit -> 'a) 
            and  'a: struct 
            and  'a :> System.ValueType> = struct
    val public Data: Vector<'a>
    new (d: Vector<'a>) = {Data = d}
    static member inline (*) (u: PackedScalars<'a, 'm1>, v: PackedScalars<'a, 'm2>) = u.Data * v.Data |> PackedScalars<'a, 'm1*'m2>
end

type ScalarField<'a, [<Measure>] 'm
            when 'a: (new: unit -> 'a) 
            and  'a: struct 
            and  'a :> System.ValueType> = struct
    val public Data: 'a[]
    member self.Item with inline get i                    = Scalar<'a, 'm>(self.Data.[i])
                     and  inline set i (v: Scalar<'a,'m>) = self.Data.[i] <- v.Data
    member self.Packed 
           with inline get i                          = Vector<'a>(self.Data, i) |> PackedScalars<_,'m>
           and  inline set i (v: PackedScalars<_,'m>) = v.Data.CopyTo(self.Data, i)
    new (d:'a[]) = {Data = d}
    new (count: int) = {Data = Array.zeroCreate count}
end

... 这意味着我们基本上不是通过使用 float<'m> 之类的 "refinement" 类型来跟踪单位,而是仅通过具有辅助 type/units 参数的包装类型。

不过,我仍然希望有人能提出更好的主意。 :)