如何对复杂数据类型进行自动区分



给定一个基于向量的非常简单的矩阵定义:

import Numeric.AD
import qualified Data.Vector as V
newtype Mat a = Mat { unMat :: V.Vector a }
scale' f = Mat . V.map (*f) . unMat
add' a b = Mat $ V.zipWith (+) (unMat a) (unMat b)
sub' a b = Mat $ V.zipWith (-) (unMat a) (unMat b)
mul' a b = Mat $ V.zipWith (*) (unMat a) (unMat b)
pow' a e = Mat $ V.map (^e) (unMat a)
sumElems' :: Num a => Mat a -> a
sumElems' = V.sum . unMat

(用于演示目的...我正在使用hmatrix,但认为问题以某种方式存在(

和一个误差函数(eq3(:

eq1' :: Num a => [a] -> [Mat a] -> Mat a
eq1' as φs = foldl1 add' $ zipWith scale' as φs
eq3' :: Num a => Mat a -> [a] -> [Mat a] -> a
eq3' img as φs = negate $ sumElems' (errImg `pow'` (2::Int))
  where errImg = img `sub'` (eq1' as φs)

为什么编译器无法推断出正确的类型?

diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m φs as0 = gradientDescent go as0
  where go xs = eq3' m xs φs

确切的错误消息是这样的:

src/Stuff.hs:59:37:
    Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
    from the context (Fractional a, Ord a)
      bound by the type signature for
                 diffTest :: (Fractional a, Ord a) =>
                             Mat a -> [Mat a] -> [a] -> [[a]]
      at src/Stuff.hs:58:13-69
    or from (reflection-1.5.1.2:Data.Reflection.Reifies
               s Numeric.AD.Internal.Reverse.Tape)
      bound by a type expected by the context:
                 reflection-1.5.1.2:Data.Reflection.Reifies
                   s Numeric.AD.Internal.Reverse.Tape =>
                 [Numeric.AD.Internal.Reverse.Reverse s a]
                 -> Numeric.AD.Internal.Reverse.Reverse s a
      at src/Stuff.hs:59:21-42
      ‘a’ is a rigid type variable bound by
          the type signature for
            diffTest :: (Fractional a, Ord a) =>
                        Mat a -> [Mat a] -> [a] -> [[a]]
          at src//Stuff.hs:58:13
    Expected type: [Numeric.AD.Internal.Reverse.Reverse s a]
                   -> Numeric.AD.Internal.Reverse.Reverse s a
      Actual type: [a] -> a
    Relevant bindings include
      go :: [a] -> a (bound at src/Stuff.hs:60:9)
      as0 :: [a] (bound at src/Stuff.hs:59:15)
      φs :: [Mat a] (bound at src/Stuff.hs:59:12)
      m :: Mat a (bound at src/Stuff.hs:59:10)
      diffTest :: Mat a -> [Mat a] -> [a] -> [[a]]
        (bound at src/Stuff.hs:59:1)
    In the first argument of ‘gradientDescent’, namely ‘go’
    In the expression: gradientDescent go as0

ad 中的 gradientDescent 函数的类型为

gradientDescent :: (Traversable f, Fractional a, Ord a) =>
                   (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) ->
                   f a -> [f a]

它的第一个参数需要一个类型为 f r -> r 的函数,其中 r forall s. (Reverse s a)go具有类型 [a] -> a其中a是绑定在 diffTest 签名中的类型。这些a是相同的,但Reverse s aa不同。

Reverse 类型具有许多类型类的实例,这些实例允许我们将a转换为Reverse s a或转换为 。最明显的是Fractional a => Fractional (Reverse s a),它将使我们能够使用 realToFraca s 转换为 Reverse s a s 。

为此,我们需要能够在Mat a上映射a -> b函数以获得Mat b。执行此操作的最简单方法是为 Mat 派生一个Functor实例。

{-# LANGUAGE DeriveFunctor #-}
newtype Mat a = Mat { unMat :: V.Vector a }
    deriving Functor

我们可以将mfs转换为任何Fractional a' => Mat a' fmap realToFrac .

diffTest m fs as0 = gradientDescent go as0
  where go xs = eq3' (fmap realToFrac m) xs (fmap (fmap realToFrac) fs)

但是有一种更好的方法隐藏在广告包中。Reverse s a在所有s上都是通用限定的,但adiffTest的类型签名中绑定的a相同。我们真的只需要一个函数a -> (forall s. Reverse s a) .此函数autoMode 类,Reverse s a 具有该类的实例。 auto有有点奇怪的类型Mode t => Scalar t -> ttype Scalar (Reverse s a) = a.专门用于Reverse auto具有的类型

auto :: (Reifies s Tape, Num a) => a -> Reverse s a

这使我们能够将Mat a转换为Mat (Reverse s a),而不会弄乱与Rational之间的转换。

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
diffTest :: forall a . (Fractional a, Ord a) => Mat a -> [Mat a] -> [a] -> [[a]]
diffTest m fs as0 = gradientDescent go as0
  where
    go :: forall t. (Scalar t ~ a, Mode t) => [t] -> t
    go xs = eq3' (fmap auto m) xs (fmap (fmap auto) fs)

相关内容

  • 没有找到相关文章

最新更新