如何在accelerate-haskell中定义矩阵乘积
How to define the matrix product in accelerate-haskell
我正在尝试在 accelerate 之上定义一个类型安全的矩阵计算库,部分是出于教育目的,部分是为了看看这是否是一种实用的方法。
但是当涉及到正确定义矩阵的乘积时我完全被卡住了 - 即在某种程度上 GHC accepts/compiles 我的代码。
我试过几次,都是这个的变体:
Linear.hs
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
import qualified Data.Array.Accelerate as A
import GHC.TypeLits
import Data.Array.Accelerate ( (:.)(..), Array
, Exp, Shape, FullShape, Slice
, DIM0, DIM1, DIM2, Z(Z)
, IsFloating, IsNum, Elt, Acc
, Any(Any), All(All))
import Data.Proxy
newtype Matrix (rows :: Nat) (cols :: Nat) a = AccMatrix {unMatrix :: Acc (Array DIM2 a)}
(#*#) :: forall k m n a. (KnownNat k, KnownNat m, KnownNat n, IsNum a, Elt a) =>
Matrix k m a -> Matrix m n a -> Matrix k n a
v #*# w = let v' = unMatrix v
w' = unMatrix w
in AccMatrix $ A.generate (A.index2 k' n') undefined
where k' = fromInteger $ natVal (Proxy :: Proxy k)
n' = fromInteger $ natVal (Proxy :: Proxy n)
aux :: Acc (Array (FullShape (Z :. Int) :. Int) e) -> Acc (Array (FullShape (Z :. All) :. Int) e) -> Exp ((Z :. Int) :. Int) -> Exp e
aux v w sh = let (Z:.i:.j) = A.unlift sh
v' = A.slice v (A.lift $ Z:.i:.All)
w' = A.slice w (A.lift $ Z:.All:.j)
in A.the $ A.sum $ A.zipWith (*) v' w'
错误stack build
给我的是
.../src/Linear.hs:196:55:
Couldn't match type ‘A.Plain ((Z :. head0) :. head1)’
with ‘(Z :. Int) :. Int’
The type variables ‘head0’, ‘head1’ are ambiguous
Expected type: Exp (A.Plain ((Z :. head0) :. head1))
Actual type: Exp ((Z :. Int) :. Int)
Relevant bindings include
i :: head0 (bound at src/Linear.hs:196:38)
j :: head1 (bound at src/Linear.hs:196:41)
In the first argument of ‘A.unlift’, namely ‘sh’
In the expression: A.unlift sh
.../src/Linear.hs:197:47:
Couldn't match type ‘FullShape (A.Plain (Z :. head0))’
with ‘Z :. Int’
The type variable ‘head0’ is ambiguous
Expected type: Acc
(Array (FullShape (A.Plain (Z :. head0) :. All)) e)
Actual type: Acc (Array (FullShape (Z :. Int) :. Int) e)
Relevant bindings include
v' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
(bound at src/Linear.hs:197:34)
i :: head0 (bound at src/Linear.hs:196:38)
In the first argument of ‘A.slice’, namely ‘v’
In the expression: A.slice v (A.lift $ Z :. i :. All)
.../src/Linear.hs:198:39:
Couldn't match type ‘A.SliceShape (A.Plain ((Z :. All) :. head1))’
with ‘A.SliceShape (A.Plain (Z :. head0)) :. Int’
The type variables ‘head0’, ‘head1’ are ambiguous
Expected type: Acc
(Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
Actual type: Acc
(Array (A.SliceShape (A.Plain ((Z :. All) :. head1))) e)
Relevant bindings include
w' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
(bound at src/Linear.hs:198:34)
v' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
(bound at src/Linear.hs:197:34)
i :: head0 (bound at src/Linear.hs:196:38)
j :: head1 (bound at src/Linear.hs:196:41)
In the expression: A.slice w (A.lift $ Z :. All :. j)
In an equation for ‘w'’: w' = A.slice w (A.lift $ Z :. All :. j)
.../src/Linear.hs:198:47:
Couldn't match type ‘FullShape (A.Plain ((Z :. All) :. head1))’
with ‘(Z :. Int) :. Int’
The type variable ‘head1’ is ambiguous
Expected type: Acc
(Array (FullShape (A.Plain ((Z :. All) :. head1))) e)
Actual type: Acc (Array (FullShape (Z :. All) :. Int) e)
Relevant bindings include
j :: head1 (bound at src/Linear.hs:196:41)
In the first argument of ‘A.slice’, namely ‘w’
In the expression: A.slice w (A.lift $ Z :. All :. j)
我查阅了 Accelerate, and I am also reading accelerate-arithmetic 的文档,该文档具有类似的目的,但没有使用 TypeLits
来断言 array/vector 维度。
我还尝试制作一个原始版本(即没有我自己的矩阵类型),以防我的类型错误,我认为这对 slice
的用法有同样的误解。为了完整起见,我将其包括在内,我可以添加错误消息,但我选择省略它们,因为我认为它们与上述问题无关。
(#*#) :: forall a. (IsNum a, Elt a) =>
Acc (Array DIM2 a) -> Acc (Array DIM2 a) -> Maybe (Acc (Array DIM2 a))
v #*# w = let Z:.k :.m = A.unlift $ A.arrayShape $ I.run v
Z:.m':.n = A.unlift $ A.arrayShape $ I.run w
in if m /= m'
then Nothing
else Just $ AccMatrix $ A.generate (A.index2 k n) (aux v w)
where aux :: Acc (Array DIM2 a) -> Acc (Array DIM2 a) -> Exp DIM2 -> Exp a
aux v w sh = let (Z:.i:.j) = A.unlift sh
v' = A.slice v (A.lift $ Z:.i:.All)
w' = A.slice w (A.lift $ Z:.All:.j)
in A.the $ A.sum $ A.zipWith (*) v' w'
您的代码实际上是正确的。不幸的是,类型检查器不够聪明,无法弄清楚,所以你必须帮助它:
let (Z:.i:.j) = A.unlift sh
变成
let (Z:.i:.j) = A.unlift sh :: (Z :. Exp Int) :. Exp Int
这里的关键是 A.unlift :: A.Unlift c e => c (A.Plain e) -> e
但 A.Plain
是关联类型族(因此是非单射的),因此没有类型签名就无法确定类型 e
,并且 e
需要 select 用于 Unlift c e
的实例。这就是 'ambiguous type' 错误的来源——实际上 e
是不明确的。
您还有一个不相关的错误。 aux
应具有类型
aux :: (IsNum e, Elt e) => ...
或
aux :: (e ~ a) => ...
在后一种情况下,a
是 (#*#)
类型签名中的一个,因此它已经具有约束条件 IsNum, Elt
我正在尝试在 accelerate 之上定义一个类型安全的矩阵计算库,部分是出于教育目的,部分是为了看看这是否是一种实用的方法。
但是当涉及到正确定义矩阵的乘积时我完全被卡住了 - 即在某种程度上 GHC accepts/compiles 我的代码。
我试过几次,都是这个的变体:
Linear.hs
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
import qualified Data.Array.Accelerate as A
import GHC.TypeLits
import Data.Array.Accelerate ( (:.)(..), Array
, Exp, Shape, FullShape, Slice
, DIM0, DIM1, DIM2, Z(Z)
, IsFloating, IsNum, Elt, Acc
, Any(Any), All(All))
import Data.Proxy
newtype Matrix (rows :: Nat) (cols :: Nat) a = AccMatrix {unMatrix :: Acc (Array DIM2 a)}
(#*#) :: forall k m n a. (KnownNat k, KnownNat m, KnownNat n, IsNum a, Elt a) =>
Matrix k m a -> Matrix m n a -> Matrix k n a
v #*# w = let v' = unMatrix v
w' = unMatrix w
in AccMatrix $ A.generate (A.index2 k' n') undefined
where k' = fromInteger $ natVal (Proxy :: Proxy k)
n' = fromInteger $ natVal (Proxy :: Proxy n)
aux :: Acc (Array (FullShape (Z :. Int) :. Int) e) -> Acc (Array (FullShape (Z :. All) :. Int) e) -> Exp ((Z :. Int) :. Int) -> Exp e
aux v w sh = let (Z:.i:.j) = A.unlift sh
v' = A.slice v (A.lift $ Z:.i:.All)
w' = A.slice w (A.lift $ Z:.All:.j)
in A.the $ A.sum $ A.zipWith (*) v' w'
错误stack build
给我的是
.../src/Linear.hs:196:55:
Couldn't match type ‘A.Plain ((Z :. head0) :. head1)’
with ‘(Z :. Int) :. Int’
The type variables ‘head0’, ‘head1’ are ambiguous
Expected type: Exp (A.Plain ((Z :. head0) :. head1))
Actual type: Exp ((Z :. Int) :. Int)
Relevant bindings include
i :: head0 (bound at src/Linear.hs:196:38)
j :: head1 (bound at src/Linear.hs:196:41)
In the first argument of ‘A.unlift’, namely ‘sh’
In the expression: A.unlift sh
.../src/Linear.hs:197:47:
Couldn't match type ‘FullShape (A.Plain (Z :. head0))’
with ‘Z :. Int’
The type variable ‘head0’ is ambiguous
Expected type: Acc
(Array (FullShape (A.Plain (Z :. head0) :. All)) e)
Actual type: Acc (Array (FullShape (Z :. Int) :. Int) e)
Relevant bindings include
v' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
(bound at src/Linear.hs:197:34)
i :: head0 (bound at src/Linear.hs:196:38)
In the first argument of ‘A.slice’, namely ‘v’
In the expression: A.slice v (A.lift $ Z :. i :. All)
.../src/Linear.hs:198:39:
Couldn't match type ‘A.SliceShape (A.Plain ((Z :. All) :. head1))’
with ‘A.SliceShape (A.Plain (Z :. head0)) :. Int’
The type variables ‘head0’, ‘head1’ are ambiguous
Expected type: Acc
(Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
Actual type: Acc
(Array (A.SliceShape (A.Plain ((Z :. All) :. head1))) e)
Relevant bindings include
w' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
(bound at src/Linear.hs:198:34)
v' :: Acc (Array (A.SliceShape (A.Plain (Z :. head0)) :. Int) e)
(bound at src/Linear.hs:197:34)
i :: head0 (bound at src/Linear.hs:196:38)
j :: head1 (bound at src/Linear.hs:196:41)
In the expression: A.slice w (A.lift $ Z :. All :. j)
In an equation for ‘w'’: w' = A.slice w (A.lift $ Z :. All :. j)
.../src/Linear.hs:198:47:
Couldn't match type ‘FullShape (A.Plain ((Z :. All) :. head1))’
with ‘(Z :. Int) :. Int’
The type variable ‘head1’ is ambiguous
Expected type: Acc
(Array (FullShape (A.Plain ((Z :. All) :. head1))) e)
Actual type: Acc (Array (FullShape (Z :. All) :. Int) e)
Relevant bindings include
j :: head1 (bound at src/Linear.hs:196:41)
In the first argument of ‘A.slice’, namely ‘w’
In the expression: A.slice w (A.lift $ Z :. All :. j)
我查阅了 Accelerate, and I am also reading accelerate-arithmetic 的文档,该文档具有类似的目的,但没有使用 TypeLits
来断言 array/vector 维度。
我还尝试制作一个原始版本(即没有我自己的矩阵类型),以防我的类型错误,我认为这对 slice
的用法有同样的误解。为了完整起见,我将其包括在内,我可以添加错误消息,但我选择省略它们,因为我认为它们与上述问题无关。
(#*#) :: forall a. (IsNum a, Elt a) =>
Acc (Array DIM2 a) -> Acc (Array DIM2 a) -> Maybe (Acc (Array DIM2 a))
v #*# w = let Z:.k :.m = A.unlift $ A.arrayShape $ I.run v
Z:.m':.n = A.unlift $ A.arrayShape $ I.run w
in if m /= m'
then Nothing
else Just $ AccMatrix $ A.generate (A.index2 k n) (aux v w)
where aux :: Acc (Array DIM2 a) -> Acc (Array DIM2 a) -> Exp DIM2 -> Exp a
aux v w sh = let (Z:.i:.j) = A.unlift sh
v' = A.slice v (A.lift $ Z:.i:.All)
w' = A.slice w (A.lift $ Z:.All:.j)
in A.the $ A.sum $ A.zipWith (*) v' w'
您的代码实际上是正确的。不幸的是,类型检查器不够聪明,无法弄清楚,所以你必须帮助它:
let (Z:.i:.j) = A.unlift sh
变成
let (Z:.i:.j) = A.unlift sh :: (Z :. Exp Int) :. Exp Int
这里的关键是 A.unlift :: A.Unlift c e => c (A.Plain e) -> e
但 A.Plain
是关联类型族(因此是非单射的),因此没有类型签名就无法确定类型 e
,并且 e
需要 select 用于 Unlift c e
的实例。这就是 'ambiguous type' 错误的来源——实际上 e
是不明确的。
您还有一个不相关的错误。 aux
应具有类型
aux :: (IsNum e, Elt e) => ...
或
aux :: (e ~ a) => ...
在后一种情况下,a
是 (#*#)
类型签名中的一个,因此它已经具有约束条件 IsNum, Elt