AD 反射 - 它是如何工作的?

AD Reflection - How does it work?

我看过 ad 程序包,我了解它如何通过提供 class Floating 的不同实例然后实施衍生规则来自动微分。

但是在例子中

Prelude Debug.SimpleReflect Numeric.AD> diff atanh x
recip (1 - x * x) * 1

我们看到它可以将 函数 表示为 AST 并将它们显示为带有变量名的字符串。

我想知道他们是怎么做到的,因为当我写的时候:

f :: Floating a => a -> a
f x = x^2

不管我提供什么实例,我都会得到一个函数 f :: Something -> Something 而不是像 f :: ASTf :: String

这样的表示

实例不能"know"参数是什么

他们是怎么做到的?

其实和AD包没有关系,和diff atanh x中的x有关。

为了看到这个,让我们定义我们自己的 AST 类型

data AST = AST :+ AST
         | AST :* AST
         | AST :- AST
         | Negate AST
         | Abs AST
         | Signum AST
         | FromInteger Integer
         | Variable String

我们可以为这个类型定义一个Num实例

instance Num (AST) where
  (+) = (:+)
  (*) = (:*)
  (-) = (:-)
  negate = Negate
  abs = Abs
  signum = Signum
  fromInteger = FromInteger

还有一个Show实例

instance Show (AST) where
  showsPrec p (a :+ b) = showParen (p > 6) (showsPrec 6 a . showString " + " . showsPrec 6 b)
  showsPrec p (a :* b) = showParen (p > 7) (showsPrec 7 a . showString " * " . showsPrec 7 b)
  showsPrec p (a :- b) = showParen (p > 6) (showsPrec 6 a . showString " - " . showsPrec 7 b)
  showsPrec p (Negate a) = showParen (p >= 10) (showString "negate " . showsPrec 10 a)
  showsPrec p (Abs a) = showParen (p >= 10) (showString "abs " . showsPrec 10 a)
  showsPrec p (Signum a) = showParen (p >= 10) (showString "signum " . showsPrec 10 a)
  showsPrec p (FromInteger n) = showsPrec p n
  showsPrec _ (Variable v) = showString v

所以现在如果我们定义一个函数:

f :: Num a => a -> a
f a = a ^ 2

和一个 AST 变量:

x :: AST
x = Variable "x"

我们可以 运行 生成整数值或 AST 值的函数:

λ f 5
25
λ f x
x * x

如果我们希望能够将我们的 AST 类型与您的函数 f :: Floating a => a -> a; f x = x^2 一起使用,我们需要扩展其定义以允许我们实现 Floating (AST).