AD反射 - 它是如何工作的



我看过ad包,我了解它如何通过提供不同的class Floating实例然后实现导数规则来进行自动微分。

但在示例中

Prelude Debug.SimpleReflect Numeric.AD> diff atanh x
recip (1 - x * x) * 1

我们看到它可以将函数表示为 AST s,并将它们显示为带有变量名称的字符串。

我想知道他们是怎么做到的,因为当我写:

f :: Floating a => a -> a
f x = x^2

无论我提供什么实例,我都会得到一个函数f :: Something -> Something而不是像f :: ASTf :: String这样的表示

实例无法"知道"参数是什么。

他们怎么能做到?

实际上,它与AD包无关,而与diff atanh x中的x有关。

为了看到这一点,让我们定义我们自己的 AST 类型

data AST = AST :+ AST
         | AST :* AST
         | AST :- AST
         | Negate AST
         | Abs AST
         | Signum AST
         | FromInteger Integer
         | Variable String

我们可以为这种类型定义一个Num实例

instance Num (AST) where
  (+) = (:+)
  (*) = (:*)
  (-) = (:-)
  negate = Negate
  abs = Abs
  signum = Signum
  fromInteger = FromInteger

和一个Show实例

instance Show (AST) where
  showsPrec p (a :+ b) = showParen (p > 6) (showsPrec 6 a . showString " + " . showsPrec 6 b)
  showsPrec p (a :* b) = showParen (p > 7) (showsPrec 7 a . showString " * " . showsPrec 7 b)
  showsPrec p (a :- b) = showParen (p > 6) (showsPrec 6 a . showString " - " . showsPrec 7 b)
  showsPrec p (Negate a) = showParen (p >= 10) (showString "negate " . showsPrec 10 a)
  showsPrec p (Abs a) = showParen (p >= 10) (showString "abs " . showsPrec 10 a)
  showsPrec p (Signum a) = showParen (p >= 10) (showString "signum " . showsPrec 10 a)
  showsPrec p (FromInteger n) = showsPrec p n
  showsPrec _ (Variable v) = showString v

所以现在如果我们定义一个函数:

f :: Num a => a -> a
f a = a ^ 2

和一个 AST 变量:

x :: AST
x = Variable "x"

我们可以运行函数来生成整数值或 AST 值:

λ f 5
25
λ f x
x * x

如果我们希望能够将我们的 AST 类型与您的函数f :: Floating a => a -> a; f x = x^2一起使用,我们需要扩展其定义以允许我们实现Floating (AST)

相关内容

  • 没有找到相关文章

最新更新