使用列表优雅地实现n维矩阵乘法



列表函数允许我们非常优雅地实现任意维度的向量数学。例如:

on   = (.) . (.)
add  = zipWith (+)
sub  = zipWith (-)
mul  = zipWith (*)
dist = len `on` sub
dot  = sum `on` mul
len  = sqrt . join dot

等等。

main = print $ add [1,2,3] [1,1,1] -- [2,3,4]
main = print $ len [1,1,1]         -- 1.7320508075688772
main = print $ dot [2,0,0] [2,0,0] -- 4

当然,这不是最有效的解决方案,但很有见地,可以说mapzipWith等概括了这些向量运算。不过,有一个功能我无法优雅地实现,那就是交叉乘积。既然叉积的一个可能的n维推广是第nd矩阵行列式,我如何优雅地实现矩阵乘法

编辑:是的,我问了一个与我设置的问题完全无关的问题。Fml。

碰巧我有一些代码可以用来做n维矩阵运算,至少在我写的时候我觉得很可爱:

{-# LANGUAGE NoMonomorphismRestriction #-}
module MultiArray where
import Control.Arrow
import Control.Monad
import Data.Ix
import Data.Maybe
import Data.Array (Array)
import qualified Data.Array as A
-- {{{ from Dmwit.hs
deleteAt n   xs = take n xs ++ drop (n + 1) xs
insertAt n x xs = take n xs ++ x : drop n xs
doublify f g xs ys = f (uncurry g) (zip xs ys)
any2 = doublify any
all2 = doublify all
-- }}}
-- makes the most sense when ls and hs have the same length
instance Ix a => Ix [a] where
    range     = sequence . map range . uncurry zip
    inRange   = all2 inRange . uncurry zip
    rangeSize = product . uncurry (zipWith (curry rangeSize))
    index (ls, hs) xs = fst . foldr step (0, 1) $ zip indices sizes where
        indices = zipWith index (zip ls hs) xs
        sizes   = map rangeSize $ zip ls hs
        step (i, b) (s, p) = (s + p * i, p * b)
fold :: (Enum i, Ix i) => ([a] -> b) -> Int -> Array [i] a -> Array [i] b
fold f n a = A.array newBound assocs where
    (oldLowBound, oldHighBound) = A.bounds a
    (newLowBoundBeg , dimLow : newLowBoundEnd ) = splitAt n oldLowBound
    (newHighBoundBeg, dimHigh: newHighBoundEnd) = splitAt n oldHighBound
    assocs   = [(beg ++ end, f [a A.! (beg ++ i : end) | i <- [dimLow..dimHigh]])
               | beg <- range (newLowBoundBeg, newHighBoundBeg)
               , end <- range (newLowBoundEnd, newHighBoundEnd)
               ]
    newBound = (newLowBoundBeg ++ newLowBoundEnd, newHighBoundBeg ++ newHighBoundEnd)
flatten a = check a >> return value where
    check = guard . (1==) . length . fst . A.bounds
    value = A.ixmap ((head *** head) . A.bounds $ a) return a
elementWise :: (MonadPlus m, Ix i) => (a -> b -> c) -> Array i a -> Array i b -> m (Array i c)
elementWise f a b = check >> return value where
    check = guard $ A.bounds a == A.bounds b
    value = A.listArray (A.bounds a) (zipWith f (A.elems a) (A.elems b))
unsafeFlatten       a   = fromJust $ flatten       a
unsafeElementWise f a b = fromJust $ elementWise f a b
matrixMult a b = fold sum 1 $ unsafeElementWise (*) a' b' where
    aBounds = (join (***) (!!0)) $ A.bounds a
    bBounds = (join (***) (!!1)) $ A.bounds b
    a' = copy 2 bBounds a
    b' = copy 0 aBounds b
bijection f g a = A.ixmap ((f *** f) . A.bounds $ a) g a
unFlatten       = bijection return head
matrixTranspose = bijection reverse reverse
copy n (low, high) a = A.ixmap (newBounds a) (deleteAt n) a where
    newBounds = (insertAt n low *** insertAt n high) . A.bounds

这里最可爱的部分是matrixMult,它是唯一专门用于二维数组的操作之一。它沿着一个维度扩展它的第一个自变量(通过将二维对象的副本放入三维对象的每个切片中(;沿着另一个扩展其第二个;进行逐点乘法(现在是在三维数组中(;然后通过求和来折叠所制造的第三维。很不错。

相关内容

  • 没有找到相关文章

最新更新