Minimal Numeric.AD example won't compile

I'm trying to compile the following minimal example using Numeric.AD:

import Numeric.AD 
timeAndGrad f l = grad f l
main = putStrLn "hi"

When I compile it, I get this error:

test.hs:3:24:
    Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
                                       s a)
                                  -> Numeric.AD.Internal.Reverse.Reverse s a’
                with actual type ‘t’
      because type variable ‘s’ would escape its scope
    This (rigid, skolem) type variable is bound by
      a type expected by the context:
        Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
        f (Numeric.AD.Internal.Reverse.Reverse s a)
        -> Numeric.AD.Internal.Reverse.Reverse s a
      at test.hs:3:19-26
    Relevant bindings include
      l :: f a (bound at test.hs:3:15)
      f :: t (bound at test.hs:3:13)
      timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
    In the first argument of ‘grad’, namely ‘f’
    In the expression: grad f l

Any clue as to why this is happening? From looking at previous examples, I understand that this is "flattening" the type of grad:

grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a

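But I actually need to do something like this in my code. In fact, this is the most minimal example that won't compile. The more complicated thing I want to do looks something like this: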

example :: SomeType
example f x args = (do stuff with the gradient and gradient "function")
    where gradient = grad f x
          gradientFn = grad f
          (other where clauses involving gradient and gradient "function")

Here is a slightly more complicated version, with type signatures, that does compile:

{-# LANGUAGE RankNTypes #-}

import Numeric.AD 
import Numeric.AD.Internal.Reverse

-- compiles but I can't figure out how to use it in code
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
grad2 f l = grad f l

-- compiles with the right type, but the resulting gradient is all 0s...
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
grad2' f l = grad f' l
       where f' = Lift . f . extractAll
       -- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work

extractAll :: [Reverse t a] -> [a]
extractAll xs = map extract xs
           where extract (Lift x) = x -- non-exhaustive pattern match

dist :: (Show a, Num a, Floating a) => [a] -> a
dist [x, y] = sqrt(x^2 + y^2)

-- incorrect output: [0.0, 0.0]
main = putStrLn $ show $ grad2' dist [1,2]

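However, I don't know how to use the first version, grad2, in my code, because I don't know how to deal with Reverse s a. The second version, grad2', has the right type because I use the internal constructor Lift to create a Reverse s a, but I must not understand how the internals (specifically the parameter s) work, because the output gradient is all 0s. Using the other constructor, Reverse (not shown here), also produces the wrong gradient.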

Alternatively, are there examples of libraries or code where people have used the ad package? I think my use case is a very common one.

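With where f' = Lift . f . extractAll you essentially create a back door into the underlying automatic-differentiation type, one that throws away all the derivatives and keeps only the constant values. If you then use this for grad, it's hardly surprising that you get a zero result!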
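To make the point concrete, here is a minimal sketch (not taken from the question) that contrasts a result built from a lifted constant with a result built from the tracked inputs. The names constantResult and trackedResult are made up for illustration; the only library call assumed is grad from Numeric.AD.

```haskell
import Numeric.AD (grad)

-- The literal 5 is lifted into the AD type as a constant, so it carries
-- no derivative information, much like a value wrapped with Lift.
constantResult :: [Double]
constantResult = grad (\_ -> 5) [1, 2]

-- Here the result is computed from the tracked variables themselves,
-- so the derivatives survive: d(x*y)/dx = y and d(x*y)/dy = x.
trackedResult :: [Double]
trackedResult = grad (\[x, y] -> x * y) [1, 2]

main :: IO ()
main = do
  print constantResult  -- [0.0,0.0]
  print trackedResult   -- [2.0,1.0]
```

Any value that enters the computation as a plain constant behaves like the first case, which is what happens to the whole result of f once it has been run on extracted values and re-wrapped.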

The sensible way is to just use grad as it is:

dist :: Floating a => [a] -> a
dist [x, y] = sqrt $ x^2 + y^2
-- preferable is of course `dist = sqrt . sum . map (^2)`

main = print $ grad dist [1,2]
-- output: [0.4472135954999579,0.8944271909999159]

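You don't really need to know anything more sophisticated to use automatic differentiation. As long as you only differentiate Num- or Floating-polymorphic functions, everything will work as-is. If you need to differentiate a function that's passed in as an argument, you need to make that argument rank-2 polymorphic. (An alternative would be to switch to the rank-1 versions of the ad functions, but I dare say that's less elegant and doesn't really gain you much.)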

{-# LANGUAGE Rank2Types, UnicodeSyntax #-}

import Numeric.AD

mainWith :: (∀n . Floating n => [n] -> n) -> IO ()
mainWith f = print $ grad f [1,2]

main = mainWith dist
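Applied to the timeAndGrad function from the question, the same idea might look like the sketch below. The concrete choices here are assumptions for illustration, not the only possible ones: the [Double] element type, the Floating constraint, the pair returned by example, and the doubled evaluation point are all hypothetical.

```haskell
{-# LANGUAGE RankNTypes #-}

import Numeric.AD (grad)

-- The explicit rank-2 signature is what makes this compile: without it,
-- GHC infers a monomorphic type for f and the type variable s escapes.
timeAndGrad :: (forall n. Floating n => [n] -> n) -> [Double] -> [Double]
timeAndGrad f l = grad f l

-- Hypothetical shape for the larger example: the gradient at x, plus a
-- gradient "function" that can be reused at other points.
example :: (forall n. Floating n => [n] -> n) -> [Double] -> ([Double], [Double])
example f x = (gradient, gradientFn (map (* 2) x))
  where
    gradient = grad f x

    gradientFn :: [Double] -> [Double]
    gradientFn = grad f

dist :: Floating a => [a] -> a
dist [x, y] = sqrt $ x^2 + y^2

main :: IO ()
main = do
  print (timeAndGrad dist [1, 2])
  print (example dist [1, 2])
```

Because f stays polymorphic inside example, it can be handed to grad as often as needed, and the same f could also be evaluated at plain Double values elsewhere in the where clauses.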