类型级编程来表示多维数组(张量)

Type level programming to represent multidimensional arrays (Tensors)

我想要一个类型来以类型安全的方式表示多维数组(张量)。所以我可以这样写:zero :: Tensor (5,3,2) Integer 这将表示具有 5 个元素的多维数组,每个元素有 3 个元素,每个元素有 2 个元素,其中所有元素都是 Integers

您将如何使用类型级编程来定义此类型?

编辑:

在 Alec 的精彩回答之后,使用 GADTs 实现了这一点,

我想知道您是否可以更进一步,并支持 class Tensor 和张量操作以及张量序列化的多种实现

这样你可以有例如:

所有类型安全且易于使用。

我的意图是在 Haskell 中创建一个与 tensor-flow 非常相似但类型安全且可扩展性更强的库,使用 automatic differentiation (ad library), and exact real arithmetic (exact-real library)

我认为像 Haskell 这样的函数式语言比以某种方式萌芽的 python 生态系统更适合这些事情(在我看来对所有事情)。

虽然我看到了潜力,但我对这种类型级编程还不够精通(或不够聪明),所以我不知道如何在 Haskell 中实现这样的东西并且让它编译。

这就是我需要你帮助的地方。

这是一种方法 (here is a complete Gist)。我们坚持使用 Peano 数而不是 GHC 的类型级别 Nat 只是因为归纳法对它们效果更好。

{-# LANGUAGE GADTs, PolyKinds, DataKinds, TypeOperators, FlexibleInstances, FlexibleContexts #-}

import Data.Foldable
import Text.PrettyPrint.HughesPJClass

data Nat = Z | S Nat

-- Some type synonyms that simplify uses of 'Nat'
type N0 = Z
type N1 = S N0
type N2 = S N1
type N3 = S N2
type N4 = S N3
type N5 = S N4
type N6 = S N5
type N7 = S N6
type N8 = S N7
type N9 = S N8

-- Similar to lists, but indexed over their length
data Vector (dim :: Nat) a where
  Nil    :: Vector Z a
  (:-)   :: a -> Vector n a -> Vector (S n) a

infixr 5 :-

data Tensor (dim :: [Nat]) a where
  Scalar :: a -> Tensor '[] a
  Tensor :: Vector d (Tensor ds a) -> Tensor (d : ds) a

为了显示这些类型,我们将使用 pretty 包(GHC 已经附带)。

instance (Foldable (Vector n), Pretty a) => Pretty (Vector n a) where
  pPrint = braces . sep . punctuate (text ",") . map pPrint . toList

instance Pretty a => Pretty (Tensor '[] a) where
  pPrint (Scalar x) = pPrint x

instance (Pretty (Tensor ds a), Pretty a, Foldable (Vector d)) => Pretty (Tensor (d : ds) a) where
  pPrint (Tensor xs) = pPrint xs

然后这里是我们的数据类型的 Foldable 实例(这里并不奇怪 - 我包含它只是因为您需要它来编译 Pretty 实例):

instance Foldable (Vector Z) where
  foldMap f Nil = mempty

instance Foldable (Vector n) => Foldable (Vector (S n)) where
  foldMap f (x :- xs) = f x `mappend` foldMap f xs


instance Foldable (Tensor '[]) where
  foldMap f (Scalar x) = f x

instance (Foldable (Vector d), Foldable (Tensor ds)) => Foldable (Tensor (d : ds)) where
  foldMap f (Tensor xs) = foldMap (foldMap f) xs

最后,回答您问题的部分:我们可以定义 Applicative (Vector n)Applicative (Tensor ds),类似于 Applicative ZipList 的定义方式(除了 pure 没有 return 和空列表 - 它 return 是一个正确长度的列表)。

instance Applicative (Vector Z) where
  pure _ = Nil
  Nil <*> Nil = Nil

instance Applicative (Vector n) => Applicative (Vector (S n)) where
  pure x = x :- pure x
  (x :- xs) <*> (y :- ys) = x y :- (xs <*> ys)


instance Applicative (Tensor '[]) where
  pure = Scalar
  Scalar x <*> Scalar y = Scalar (x y)

instance (Applicative (Vector d), Applicative (Tensor ds)) => Applicative (Tensor (d : ds)) where
  pure x = Tensor (pure (pure x))
  Tensor xs <*> Tensor ys = Tensor ((<*>) <$> xs <*> ys)

然后,在 GHCi 中,使 zero 函数非常简单:

ghci> :set -XDataKinds
ghci> zero = pure 0
ghci> pPrint (zero :: Tensor [N5,N3,N2] Integer)
{{{0, 0}, {0, 0}, {0, 0}},
 {{0, 0}, {0, 0}, {0, 0}},
 {{0, 0}, {0, 0}, {0, 0}},
 {{0, 0}, {0, 0}, {0, 0}},
 {{0, 0}, {0, 0}, {0, 0}}}