如何将依赖大小的数组推广到 n 维?

How to generalize dependently sized arrays to n dimensions?

我已经研究这个有一段时间了,但我一直没能说服 GHC 完成这项工作。

基本上,在 Haskell/GHC 的当前版本中创建依赖大小的数组非常容易:

newtype Arr1 (w :: Nat) a = Arr1 (Int -> a)
newtype Arr2 (w :: Nat) (h :: Nat) a = Arr2 (Int -> a)

ix2 :: forall w h a. (KnownNat w) => Arr2 w h a -> Int -> Int -> a
ix2 (Arr2 f) x y = f ( y * w + x )
    where w = fromInteger $ natVal (Proxy :: Proxy w)

sub2 :: forall w h a. (KnownNat w) => Arr2 w h a -> Int -> Arr1 w a
sub2 (Arr2 f) y = Arr1 $ \x -> f (y * w + x)
    where w = fromInteger $ natVal (Proxy :: Proxy w)

mkArr2V :: forall w h a. (V.Unbox a, KnownNat w, KnownNat h) => V.Vector a -> Arr2 w h a
mkArr2V v = Arr2 $ (v V.!)

-- and so on ... errorchecking neglected

但是当前的 GHC 版本为我们提供了更多的表达能力。基本上应该可以为此创建一个类型:

newtype Mat (s :: [Nat]) a = Mat (Int -> a)

-- create array backed by vector
mkMatV :: forall s a. V.Vector a -> Mat s a
mkMatV v = Mat $ (v V.!)

这适用于 GHCi:

>>> let m = mkMatV (V.fromList [1,2,3,4]) :: Mat [2,2] Double
>>> :t m
m :: Mat '[2, 2] Double

但到目前为止,我不确定如何完成对数组的索引。一个简单的解决方案是对 nd 和 1d 索引使用两个不同的函数。请注意,这不是类型检查。

-- slice from nd array
(!) :: forall s ss a. (KnownNat s) => Mat (s ': ss) a -> Int -> Mat ss a
(!) (Mat f) o = Mat $ \i -> f (o*s+i)
    where s = fromInteger $ natVal (Proxy :: Proxy (sum ss))

-- index into 1d array
(#) :: forall s ss a. (KnownNat s) => Mat (s ': '[]) a -> Int -> a
(#) (Mat f) o = Mat $ \i -> f o

大概可以这样使用:

>>> :t m ! 0
Mat [2] Double
>>> m ! 0 # 0
1

并不是说必须按 z,y,x 顺序给出索引。我首选的解决方案将提供一个索引函数,该函数根据数组的维数更改其 return 类型。据我所知,这可以通过使用类型 类 以某种方式实现,但我还没有弄清楚。如果可以按 "natural" x,y,z 顺序给出索引,则加分。

tl;dr: 我要求一个函数,该函数对上面定义的 n 维数组进行索引。

这确实可以用 classes 类型来完成。一些预赛:

{-# LANGUAGE
  UndecidableInstances, MultiParamTypeClasses, TypeFamilies,
  ScopedTypeVariables, FunctionalDependencies, TypeOperators,
  DataKinds, FlexibleInstances #-}

import qualified Data.Vector as V

import GHC.TypeLits
import Data.Proxy

newtype NVec (shape :: [Nat]) a = NVec {_data :: V.Vector a}

首先,我们应该能够说出 n 维向量的整体平面大小。我们将使用它来计算索引的步幅。我们使用 class 在类型级列表上递归。

class FlatSize (sh :: [Nat]) where
  flatSize :: Proxy sh -> Int

instance FlatSize '[] where
  flatSize _ = 1

instance (KnownNat s, FlatSize ss) => FlatSize (s ': ss) where
  flatSize _ = fromIntegral (natVal (Proxy :: Proxy s)) * flatSize (Proxy :: Proxy ss)       

我们也使用类型 class 进行索引。我们为一维情况(我们简单地对基础向量进行索引)和高维情况(我们 return 一个新的 NVec 减少了维度)提供了不同的实例。不过,我们对这两种情况使用相同的 class。

infixl 5 !                                            
class Index (sh :: [Nat]) (a :: *) (b :: *) | sh a -> b where
  (!) :: NVec sh a -> Int -> b

instance Index '[s] a a where
  (NVec v) ! i = v V.! i         

instance (Index (s2 ': ss) a b, FlatSize (s2 ': ss), res ~ NVec (s2 ': ss) a) 
  => Index (s1 ': s2 ': ss) a res where
  (NVec v) ! i = NVec (V.slice (i * stride) stride v)
    where stride = flatSize (Proxy :: Proxy (s2 ': ss))

索引到一个更高维的向量只是用结果向量的平面大小和适当的偏移量取一个切片。

一些测试:

fromList :: forall a sh. FlatSize sh => [a] -> NVec sh a
fromList as | length as == flatSize (Proxy :: Proxy sh) = NVec (V.fromList as)
fromList _ = error "fromList: initializer list has wrong size"

v3 :: NVec [2, 2, 2] Int
v3 = fromList [
  2, 4,
  5, 6,

  10, 20,
  30, 0 ]

v2 :: NVec [2, 2] Int
v2 = v3 ! 0

vElem :: Int
vElem = v3 ! 0 ! 1 ! 1 -- 6 

另外,让我也提供一个 singletons 解决方案,因为它更加方便。它让我们可以重用更多的代码(更少的自定义类型 classes 用于单个函数)并以更直接的函数式风格编写。

{-# LANGUAGE
  UndecidableInstances, MultiParamTypeClasses, TypeFamilies,
  ScopedTypeVariables, FunctionalDependencies, TypeOperators,
  DataKinds, FlexibleInstances, StandaloneDeriving, DeriveFoldable,
  GADTs, FlexibleContexts #-}

import qualified Data.Vector as V
import qualified Data.Foldable as F
import GHC.TypeLits
import Data.Singletons.Preludeimport 
import Data.Singletons.TypeLits

newtype NVec (shape :: [Nat]) a = NVec {_data :: V.Vector a}

flatSize就变得简单多了:我们只要把sh降低到数值级别,照常操作就可以了:

flatSize :: Sing (sh :: [Nat]) -> Int
flatSize = fromIntegral . product . fromSing

我们使用类型族和函数进行索引。在之前的解决方案中,我们使用实例来调度维度;这里我们对模式匹配做同样的事情:

type family Index (shape :: [Nat]) (a :: *) where
  Index (s  ': '[])       a = a
  Index (s1 ':  s2 ': ss) a = NVec (s2 ': ss) a

infixl 5 !
(!) :: forall a sh. SingI sh => NVec sh a -> Int -> Index sh a
(!) (NVec v) i = case (sing :: Sing sh) of
  SCons _ SNil       -> v V.! i
  SCons _ ss@SCons{} -> NVec (V.slice (i * stride) stride v) where
    stride = flatSize ss

我们还可以使用 Nat 单例进行安全索引和初始化(即静态检查边界和大小)。对于初始化,我们定义了一个具有静态大小 (Vec) 的列表类型。

safeIx ::
  forall a s sh i. (SingI (s ': sh), (i + 1) <= s) =>
  NVec (s ': sh) a -> Sing i -> Index (s ': sh) a
safeIx v si = v ! (fromIntegral $ fromSing si)                    

data Vec n a where
  VNil :: Vec 0 a
  (:>) :: a -> Vec (n - 1) a -> Vec n a
infixr 5 :>
deriving instance F.Foldable (Vec n)

fromVec :: forall a sh. SingI sh => Vec (Foldr (:*$) 1 sh) a -> NVec sh a
fromVec = fromList . F.toList

安全功能的一些示例:

-- Other than 8 elements in the Vec would be a type error
v3 :: NVec [2, 2, 2] Int
v3 = fromVec
     (2 :> 4  :>
      5 :> 6  :>

      10 :> 20 :>
      30 :> 0  :> VNil)

vElem :: Int
vElem = v3
  `safeIx` (sing :: Sing 0)
  `safeIx` (sing :: Sing 1)
  `safeIx` (sing :: Sing 1) -- 6