Haskell data type definition that depends on GADTs and function output

I want a tensor data structure:

data Nat where
    Zero :: Nat
    Succ :: Nat -> Nat

-- | A list of type a and of length n
data ListN a (dim :: Nat) where
    Nil  :: ListN a Zero
    Cons :: a -> ListN a n -> ListN a (Succ n)    

data Tensor a where
        Dense :: ListN a n -> ListN Int Nat -> Tensor a

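A tensor is represented by a list of elements and a list of integers giving the tensor's dimensions. For example, a ListN holding [3,4,5,6] means there are 4 dimensions, with 3, 4, 5, and 6 elements respectively. Now I want the n in the first ListN to depend on the product of all the integers stored in the second ListN, because that is the number of elements the first ListN can hold. How should I do this?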

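For this, you need the dimension vector at the type level in the Tensor type, not just as a ListN Int Nat value, so it is best to define Tensor with a dims type parameter. You may also find it more convenient to put the dimension first and the element type second, like so: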

data ListN (dim :: Nat) a where
    Nil  :: ListN Zero a
    Cons :: a -> ListN n a -> ListN (Succ n) a
infixr 5 `Cons`

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

The missing piece here is Product, a type-level function that multiplies the dimensions. It is a little tedious to write for Peano naturals, but the following works:

type family Plus m n where
  Plus (Succ m) n = Plus m (Succ n)
  Plus Zero n = n

type family Times m n where
  Times (Succ m) n = Plus n (Times m n)
  Times Zero n = Zero

type family Product (dims) where
  Product '[] = Succ Zero
  Product (m : ns) = Times m (Product ns)
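To see what these type families compute, one option is to ask GHC to prove a type equality at compile time. The sketch below is only an illustration: the synonyms One, Two, Three, Six and the name productIsSix are hypothetical and not part of the original answer, and the promoted constructors are written with explicit ticks.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Type.Equality ((:~:) (Refl))

data Nat = Zero | Succ Nat

-- Addition on Peano naturals, recursing on the first argument.
type family Plus (m :: Nat) (n :: Nat) :: Nat where
  Plus ('Succ m) n = Plus m ('Succ n)
  Plus 'Zero n = n

-- Multiplication as repeated addition.
type family Times (m :: Nat) (n :: Nat) :: Nat where
  Times ('Succ m) n = Plus n (Times m n)
  Times 'Zero n = 'Zero

-- Product of a type-level list of dimensions; the empty product is one.
type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 'Succ 'Zero
  Product (m ': ns) = Times m (Product ns)

-- Hypothetical shorthand synonyms, introduced only for this example.
type One = 'Succ 'Zero
type Two = 'Succ One
type Three = 'Succ Two
type Six = 'Succ ('Succ ('Succ Three))

-- This definition compiles only if GHC reduces
-- Product '[One, Two, Three] to Six.
productIsSix :: Product '[One, Two, Three] :~: Six
productIsSix = Refl
```

In GHCi, `:kind! Product '[One, Two, Three]` should show the same reduction interactively.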

With that in place, the following type checks. Note that I made Cons an infixr operator above to avoid a lot of parentheses:

t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)

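If the number of elements is wrong, the constraint fails, so the following does not type check (the list has 5 elements, but the dimensions 1, 2, and 3 multiply to 6):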

t2 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)

Full example:

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

data Nat where
    Zero :: Nat
    Succ :: Nat -> Nat

data ListN (dim :: Nat) a where
    Nil  :: ListN Zero a
    Cons :: a -> ListN n a -> ListN (Succ n) a
infixr 5 `Cons`

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

type family Plus m n where
  Plus (Succ m) n = Plus m (Succ n)
  Plus Zero n = n

type family Times m n where
  Times (Succ m) n = Plus n (Times m n)
  Times Zero n = Zero

type family Product (dims) where
  Product '[] = Succ Zero
  Product (m : ns) = Times m (Product ns)

-- type checks
t1 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)

-- won't type check
t2 :: Tensor '[Succ Zero, Succ (Succ Zero), Succ (Succ (Succ Zero))] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
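Once a tensor is built this way, consuming it is an ordinary pattern match on Dense. As a rough sketch, the helpers listNToList and tensorElems below (hypothetical names, not part of the original answer) flatten a tensor back into a plain list; the synonyms One, Two, and Three are likewise introduced only to keep the signature short.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

data Nat = Zero | Succ Nat

-- A list of elements of type a with length dim.
data ListN (dim :: Nat) a where
  Nil  :: ListN 'Zero a
  Cons :: a -> ListN n a -> ListN ('Succ n) a
infixr 5 `Cons`

type family Plus (m :: Nat) (n :: Nat) :: Nat where
  Plus ('Succ m) n = Plus m ('Succ n)
  Plus 'Zero n = n

type family Times (m :: Nat) (n :: Nat) :: Nat where
  Times ('Succ m) n = Plus n (Times m n)
  Times 'Zero n = 'Zero

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 'Succ 'Zero
  Product (m ': ns) = Times m (Product ns)

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

-- Hypothetical helper: forget the length index and return a plain list.
listNToList :: ListN n a -> [a]
listNToList Nil = []
listNToList (Cons x xs) = x : listNToList xs

-- Hypothetical helper: the elements of a tensor in storage order.
tensorElems :: Tensor dims a -> [a]
tensorElems (Dense xs) = listNToList xs

-- Hypothetical shorthand synonyms for the dimensions 1, 2 and 3.
type One = 'Succ 'Zero
type Two = 'Succ One
type Three = 'Succ Two

example :: Tensor '[One, Two, Three] Int
example = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
```

Here `tensorElems example` evaluates to the six stored elements in order. The dimensions themselves exist only in the type, so this sketch cannot report the shape at run time.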

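As noted in the comments, there is a built-in non-Peano Nat type (in GHC.TypeLits), which you *may* find easier to work with. Rewritten to use it, the code looks like the following. The NoStarIsType extension is needed so that * is read as type-level multiplication rather than as the kind of types: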

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import GHC.TypeLits

data ListN (dim :: Nat) a where
    Nil  :: ListN 0 a
    Cons :: a -> ListN n a -> ListN (1 + n) a
infixr 5 `Cons`

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

type family Product dims where
  Product '[] = 1
  Product (m : ns) = m * Product ns

-- type checks
t1 :: Tensor '[1,2,3] Int
t1 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)

-- won't type check
t2 :: Tensor '[1,2,3] Int
t2 = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` Nil)
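One practical benefit of the built-in Nat is that KnownNat and natVal can bring type-level numbers back to the value level. The following is a minimal sketch under that assumption, not part of the original answer: the class KnownDims, its method dimsVal, and the function shape are hypothetical names introduced here to read a tensor's shape from its type.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits

data ListN (dim :: Nat) a where
  Nil  :: ListN 0 a
  Cons :: a -> ListN n a -> ListN (1 + n) a
infixr 5 `Cons`

type family Product (dims :: [Nat]) :: Nat where
  Product '[] = 1
  Product (m ': ns) = m * Product ns

data Tensor (dims :: [Nat]) a where
  Dense :: (Product dims ~ n) => ListN n a -> Tensor dims a

-- Hypothetical class: reflect a type-level list of dimensions
-- into an ordinary list of Integers.
class KnownDims (dims :: [Nat]) where
  dimsVal :: Proxy dims -> [Integer]

instance KnownDims '[] where
  dimsVal _ = []

instance (KnownNat d, KnownDims ds) => KnownDims (d ': ds) where
  dimsVal _ = natVal (Proxy :: Proxy d) : dimsVal (Proxy :: Proxy ds)

-- Hypothetical helper: the shape is read from the type alone;
-- the tensor value is never inspected.
shape :: forall dims a. KnownDims dims => Tensor dims a -> [Integer]
shape _ = dimsVal (Proxy :: Proxy dims)

example :: Tensor '[1, 2, 3] Int
example = Dense (1 `Cons` 2 `Cons` 3 `Cons` 4 `Cons` 5 `Cons` 6 `Cons` Nil)
```

With this in place, `shape example` yields the list of dimensions declared in the type of example. The same pattern would be more verbose for the Peano version, which has no built-in counterpart to natVal.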