haskell中的foldl(具有多级foldl)问题



foldr版本比foldl版本快:

文件夹版本:

cartProdN9 :: [[a]] -> [[a]]
cartProdN9 xss = 
foldr h1 [[]] xss where 
h1 xs yss = foldr g [] xs where
g x zss = foldr f zss yss where 
f xs yss = (x:xs):yss 

折页版

cartProdN11 :: [[a]] -> [[a]]
cartProdN11 xss = 
foldl h1 [[]] xss where 
h1 yss xs = foldl g [] xs where
g zss x = foldl f zss yss where 
f yss xs = (x:xs):yss 

进程cartProdN9 [[1,2]| i <- [1 .. 1000]]正常。但是cartProdN11 [[1,2]| i <- [1 .. 1000]]不好。

严格的版本fold'仍然不好:

foldl' f z []     = z
foldl' f z (x:xs) = let z' = z `f` x 
in  z' `seq` foldl' f z' xs

甚至在https://www.fpcomplete.com/haskell/tutorial/all-about-strictness/

{-# LANGUAGE BangPatterns #-}
module D where   
data StrictList a = Cons !a !(StrictList a) | Nil
strictMap :: (a -> b) -> StrictList a -> StrictList b
strictMap _ Nil = Nil
strictMap f (Cons a list) =
let !b = f a
!list' = strictMap f list
in b `seq` list' `seq` Cons b list'
strictEnum :: Int -> Int -> StrictList Int
strictEnum low high =
go low
where
go !x
| x == high = Cons x Nil
| otherwise = Cons x (go $! x + 1)
list  :: Int -> StrictList Int
list !x = Cons x (Cons x Nil)
foldlS = f z l ->
case l of
Nil -> z
Cons !x !xs -> let !z' = z `f` x
in  z' `seq` foldlS f z' xs  
listlist :: StrictList (StrictList Int)
listlist = strictMap list $! strictEnum 1 10
cartProdN12 :: StrictList (StrictList a) -> StrictList (StrictList a)
cartProdN12 xss =
foldlS h1 (Cons Nil Nil) xss where
h1 !yss !xs = foldlS g Nil xs where
g !zss !x = foldlS f zss yss where
f !yss !xs = Cons (Cons x xs ) yss
myhead  :: StrictList a ->  a
myhead =  l ->
case l of
Cons x xs -> x

r = cartProdN12 listlist
hr :: Int
hr =  myhead( myhead r)

CCD_ 4仍然太慢而无法计算。

所以我的问题是:如何使foldl版本的计算速度与foldr版本一样快?有可能吗?

流程cartProdN9[[1,2]|i<-[1..1000]]可以。

我对此深表怀疑,因为生成的列表将有2^1000个元素,所以您可能没有正确地进行基准测试。

下面是我拼凑的一个小基准,它表明简单严格的版本实际上更快:

module Main where
import Test.Tasty.Bench
cartProdN9 :: [[a]] -> [[a]]
cartProdN9 xss = 
foldr h1 [[]] xss where 
h1 xs yss = foldr g [] xs where
g x zss = foldr f zss yss where 
f xs yss = (x:xs):yss 
cartProdN11 :: [[a]] -> [[a]]
cartProdN11 xss = 
foldl h1 [[]] xss where 
h1 yss xs = foldl g [] xs where
g zss x = foldl f zss yss where 
f yss xs = (x:xs):yss 
mkBench :: ([[Int]] -> [[Int]]) -> Int -> Benchmark
mkBench f n = bench (show n) $ nf f (replicate n [1, 2])
main :: IO ()
main = defaultMain
[ bgroup "cartProdN9"  $ map (mkBench cartProdN9) [10,15,20]
, bgroup "cartProdN11" $ map (mkBench cartProdN11) [10,15,20]
]

结果:

All
cartProdN9
10: OK (0.16s)
36.7 μs ± 3.0 μs
15: OK (0.29s)
4.48 ms ± 273 μs
20: OK (5.75s)
378  ms ±  28 ms
cartProdN11
10: OK (0.28s)
33.1 μs ± 2.2 μs
15: OK (0.98s)
3.76 ms ± 292 μs
20: OK (5.22s)
337  ms ±  12 ms

mkBench函数中的nf非常重要,如果使用whnf,则会得到非常不同的结果:

All
cartProdN9
10: OK (0.14s)
122  ns ±  11 ns
15: OK (0.19s)
189  ns ±  11 ns
20: OK (0.27s)
257  ns ±  11 ns
cartProdN11
10: OK (0.18s)
10.7 μs ± 683 ns
15: OK (0.30s)
2.41 ms ± 150 μs
20: OK (0.56s)
188  ms ± 4.2 ms

最新更新