如何对复杂数据类型进行自动微分?
How to do automatic differentiation on complex datatypes?
给定一个基于向量的非常简单的矩阵定义:
import Numeric.AD
import qualified Data.Vector as V
newtype Mat a = Mat { unMat :: V.Vector a }
scale' f = Mat . V.map (*f) . unMat
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)
mul' a b = Mat $ V.zipWith (*) (unMat a) (unMat b)
pow' a e = Mat $ V.map (^e) (unMat a)
sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat
(出于演示目的...我正在使用 hmatrix,但我认为问题出在某种程度上)
和一个误差函数(eq3
):
eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs
eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2::Int))
where errImg = img `sub'` (eq1' as φs)
为什么编译器无法在其中推断出正确的类型?
diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m φs as0 = gradientDescent go as0
where go xs = eq3' m xs φs
确切的错误信息是这样的:
src/Stuff.hs:59:37:
Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
from the context (Fractional a, Ord a)
bound by the type signature for
diffTest :: (Fractional a, Ord a) =>
Mat a -> [Mat a] -> [a] -> [[a]]
at src/Stuff.hs:58:13-69
or from (reflection-1.5.1.2:Data.Reflection.Reifies
s Numeric.AD.Internal.Reverse.Tape)
bound by a type expected by the context:
reflection-1.5.1.2:Data.Reflection.Reifies
s Numeric.AD.Internal.Reverse.Tape =>
[Numeric.AD.Internal.Reverse.Reverse s a]
-> Numeric.AD.Internal.Reverse.Reverse s a
at src/Stuff.hs:59:21-42
‘a’ is a rigid type variable bound by
the type signature for
diffTest :: (Fractional a, Ord a) =>
Mat a -> [Mat a] -> [a] -> [[a]]
at src//Stuff.hs:58:13
Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
-> Numeric.AD.Internal.Reverse.Reverse s a
Actual type: [a] -> a
Relevant bindings include
go :: [a] -> a (bound at src/Stuff.hs:60:9)
as0 :: [a] (bound at src/Stuff.hs:59:15)
φs :: [Mat a] (bound at src/Stuff.hs:59:12)
m :: Mat a (bound at src/Stuff.hs:59:10)
diffTest :: Mat a -> [Mat a] -> [a] -> [[a]]
(bound at src/Stuff.hs:59:1)
In the first argument of ‘gradientDescent’, namely ‘go’
In the expression: gradientDescent go as0
gradientDescent
function from ad
的类型为
gradientDescent :: (Traversable f, Fractional a, Ord a) =>
(forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) ->
f a -> [f a]
它的第一个参数需要一个 f r -> r
类型的函数,其中 r
是 forall s. (Reverse s a)
。 go
的类型为 [a] -> a
,其中 a
是 diffTest
签名中绑定的类型。这些 a
相同,但 Reverse s a
与 a
不同。
Reverse
类型有许多类型 classes 的实例,可以让我们将 a
转换为 Reverse s a
或返回。最明显的是 Fractional a => Fractional (Reverse s a)
,它允许我们使用 realToFrac
.
将 a
s 转换为 Reverse s a
s
为此,我们需要能够将函数 a -> b
映射到 Mat a
上以获得 Mat b
。最简单的方法是为 Mat
.
派生一个 Functor
实例
{-# LANGUAGE DeriveFunctor #-}
newtype Mat a = Mat { unMat :: V.Vector a }
deriving Functor
我们可以将 m
和 fs
转换为 Fractional a' => Mat a'
和 fmap realToFrac
。
diffTest m fs as0 = gradientDescent go as0
where go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)
但是有一个更好的方法隐藏在广告包中。 Reverse s a
对所有 s
具有普遍限定,但 a
与 diffTest
的类型签名中绑定的 a
相同。我们真的只需要一个函数a -> (forall s. Reverse s a)
。此函数是 Mode
class 中的 auto
,Reverse s a
有一个实例。 auto
有稍微奇怪的类型 Mode t => Scalar t -> t
但 type Scalar (Reverse s a) = a
。专用于 Reverse
auto
具有类型
auto :: (Reifies s Tape, Num a) => a -> Reverse s a
这使我们能够将 Mat a
转换为 Mat (Reverse s a)
,而无需在 Rational
和 Rational
之间进行转换。
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
where
go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)
给定一个基于向量的非常简单的矩阵定义:
import Numeric.AD
import qualified Data.Vector as V
newtype Mat a = Mat { unMat :: V.Vector a }
scale' f = Mat . V.map (*f) . unMat
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)
mul' a b = Mat $ V.zipWith (*) (unMat a) (unMat b)
pow' a e = Mat $ V.map (^e) (unMat a)
sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat
(出于演示目的...我正在使用 hmatrix,但我认为问题出在某种程度上)
和一个误差函数(eq3
):
eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs
eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2::Int))
where errImg = img `sub'` (eq1' as φs)
为什么编译器无法在其中推断出正确的类型?
diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m φs as0 = gradientDescent go as0
where go xs = eq3' m xs φs
确切的错误信息是这样的:
src/Stuff.hs:59:37:
Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
from the context (Fractional a, Ord a)
bound by the type signature for
diffTest :: (Fractional a, Ord a) =>
Mat a -> [Mat a] -> [a] -> [[a]]
at src/Stuff.hs:58:13-69
or from (reflection-1.5.1.2:Data.Reflection.Reifies
s Numeric.AD.Internal.Reverse.Tape)
bound by a type expected by the context:
reflection-1.5.1.2:Data.Reflection.Reifies
s Numeric.AD.Internal.Reverse.Tape =>
[Numeric.AD.Internal.Reverse.Reverse s a]
-> Numeric.AD.Internal.Reverse.Reverse s a
at src/Stuff.hs:59:21-42
‘a’ is a rigid type variable bound by
the type signature for
diffTest :: (Fractional a, Ord a) =>
Mat a -> [Mat a] -> [a] -> [[a]]
at src//Stuff.hs:58:13
Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
-> Numeric.AD.Internal.Reverse.Reverse s a
Actual type: [a] -> a
Relevant bindings include
go :: [a] -> a (bound at src/Stuff.hs:60:9)
as0 :: [a] (bound at src/Stuff.hs:59:15)
φs :: [Mat a] (bound at src/Stuff.hs:59:12)
m :: Mat a (bound at src/Stuff.hs:59:10)
diffTest :: Mat a -> [Mat a] -> [a] -> [[a]]
(bound at src/Stuff.hs:59:1)
In the first argument of ‘gradientDescent’, namely ‘go’
In the expression: gradientDescent go as0
gradientDescent
function from ad
的类型为
gradientDescent :: (Traversable f, Fractional a, Ord a) =>
(forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) ->
f a -> [f a]
它的第一个参数需要一个 f r -> r
类型的函数,其中 r
是 forall s. (Reverse s a)
。 go
的类型为 [a] -> a
,其中 a
是 diffTest
签名中绑定的类型。这些 a
相同,但 Reverse s a
与 a
不同。
Reverse
类型有许多类型 classes 的实例,可以让我们将 a
转换为 Reverse s a
或返回。最明显的是 Fractional a => Fractional (Reverse s a)
,它允许我们使用 realToFrac
.
a
s 转换为 Reverse s a
s
为此,我们需要能够将函数 a -> b
映射到 Mat a
上以获得 Mat b
。最简单的方法是为 Mat
.
Functor
实例
{-# LANGUAGE DeriveFunctor #-}
newtype Mat a = Mat { unMat :: V.Vector a }
deriving Functor
我们可以将 m
和 fs
转换为 Fractional a' => Mat a'
和 fmap realToFrac
。
diffTest m fs as0 = gradientDescent go as0
where go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)
但是有一个更好的方法隐藏在广告包中。 Reverse s a
对所有 s
具有普遍限定,但 a
与 diffTest
的类型签名中绑定的 a
相同。我们真的只需要一个函数a -> (forall s. Reverse s a)
。此函数是 Mode
class 中的 auto
,Reverse s a
有一个实例。 auto
有稍微奇怪的类型 Mode t => Scalar t -> t
但 type Scalar (Reverse s a) = a
。专用于 Reverse
auto
具有类型
auto :: (Reifies s Tape, Num a) => a -> Reverse s a
这使我们能够将 Mat a
转换为 Mat (Reverse s a)
,而无需在 Rational
和 Rational
之间进行转换。
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
where
go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)