在非整数键上高效地实现Memoization



我是Haskell的新手,一直在练习一些简单的编程挑战。在过去的两天里,我一直在尝试在这里实现无界背包问题。我使用的算法在维基百科页面上有描述,尽管对于这个问题,单词"weight"被替换为单词"length"。不管怎样,我一开始是在没有记忆的情况下编写代码:

maxValue :: [(Int,Int)] -> Int -> Int
maxValue [] len = 0
maxValue ((l, val): other) len =
if l > len then 
skipValue
else 
max skipValue takeValue
where skipValue = maxValue other len
takeValue = (val + maxValue ([(l, val)] ++ other) (len - l)

我曾希望haskell会很好,并有一些像#pragma memoize这样的好语法来帮助我,但环顾四周,通过这个fibonacci问题代码解释了解决方案。

memoized_fib :: Int -> Integer
memoized_fib = (map fib [0 ..] !!)
where fib 0 = 0
fib 1 = 1
fib n = memoized_fib (n-2) + memoized_fib (n-1)

在掌握了这个例子背后的概念后,我非常失望——使用的方法非常简单,只有当1)函数的输入是一个整数,2)函数需要按f(0), f(1), f(2), ...的顺序递归计算值时才有效。但如果我的参数是向量或集合呢如果我想记住像f(n) = f(n/2) + f(n/3)这样的函数,我需要为所有小于n的I计算f(i)的值,而我不需要这些值中的大部分(其他人指出这种说法是错误的)

我试着通过传递一个备忘录表来实现我想要的东西,我们慢慢地将其填充为一个额外的参数:

maxValue :: (Map.Map (Int, Int) Int) -> [(Int,Int)] -> Int -> (Map.Map (Int, Int) Int, Int)
maxValue m [] len = (m, 0)
maxValue m ((l, val) : other) len =
if l > len then
(mapWithSkip, skipValue)
else
(mapUnion, max skipValue (takeValue+val))
where (skipMap, skipValue) = maxValue m other len
mapWithSkip = Map.insertWith' max (1 + length other, len) skipValue skipMap
(takeMap, takeValue) = maxValue m ([(l, val)] ++ other) (len - l)
mapWithTake = Map.insertWith' max (1 + length other, len) (takeValue+val) mapWithSkip
mapUnion = Map.union mapWithSkip mapWithTake

但这太慢了,我相信因为Map.union花费的时间太长了,它是O(n+m)而不是O(min(n,m))。此外,对于像memorizaton这样简单的东西来说,这段代码似乎相当混乱。对于这个特定的问题,你可能可以将这种技巧性的方法推广到二维,并进行一些额外的计算,但我想知道如何在更普遍的意义上进行记忆。如何以这种更通用的形式实现记忆,同时保持与命令式语言中的代码相同的复杂性?

如果我想记忆f(n)=f(n/2)+f(n/3)这样的函数,我需要为所有小于n的I计算f(I)的值,而我不需要这些值中的大部分。

不,惰性意味着从不计算未使用的值。你为它们分配了一个thunk,以防它们被使用,所以这是一个非零量的CPU和RAM专用于这个未使用的值,但例如,评估f 6永远不会导致评估f 5。因此,假设计算一个项目的费用远高于分配一个cons单元格的费用,并且您最终看到的可能值占总值的很大一部分,则该方法使用的浪费工作量很小。

但是如果我的参数是向量或集合呢?

使用相同的技术,但使用与列表不同的数据结构。地图是最通用的方法,前提是您的密钥是Ord,并且您可以枚举所有需要查找的密钥。

如果您不能枚举所有的键,或者您计划查找比可能的总数少得多的键,那么您可以使用State(或ST)来模拟在函数调用之间共享可写内存缓存的必要过程。

我本想向你展示这是如何工作的,但我发现你的问题陈述/链接令人困惑。你链接到的练习似乎相当于你链接到维基百科文章中的UKP,但我在那篇文章中没有看到任何类似于你的实现的东西。";预先动态编程算法";Wikipedia给出的属性被明确设计为与您给出的fib记忆示例具有完全相同的属性。关键是一个Int,数组从左到右构建:从len=0作为基本情况开始,所有其他计算都基于已经计算的值。出于某种原因,我不明白,它似乎假设每个合法大小的对象至少有一个副本,而不是至少0个;但是,如果您有不同的约束条件,这很容易解决。

您所实现的是完全不同的,从总len开始,为每个(length, value)步骤选择要切割的length大小的多少块,然后用较小的len递归,并从权重值列表中删除前面的项目。它更接近于传统的";在给定这些面额的情况下,你能用多少种方式兑换一笔货币;问题这也适用于与fib相同的从左到右的记忆方法,但有两个维度(一个维度是要更改的货币数量,另一个维度则是剩余使用的面额数量)。

我在Haskell中进行内存化的常用方法通常是MemoTrie。它很简单,很纯粹,而且它通常能满足我的需求。

不用想太多,你就可以生产:

import Data.MemoTrie (memo2)
maxValue :: [(Int,Int)] -> Int -> Int
maxValue = memo2 go
where
go [] len = 0
go lst@((l, val):other) len =
if l > len then skipValue else max skipValue takeValue
where
skipValue = maxValue other len
takeValue = val + maxValue lst (len - l)

我没有你的输入,所以我不知道会有多快——记住[(Int,Int)]输入有点奇怪。我想你也认识到了这一点,因为在你自己的尝试中,你实际上是在列表的长度上记忆,而不是列表本身。如果你想这样做,把你的列表转换成一个恒定时间的查找数组,然后进行记忆是有意义的。这就是我想到的:

import qualified GHC.Arr as Arr
maxValue :: [(Int,Int)] -> Int -> Int
maxValue lst = memo2 go 0
where
values = Arr.listArray (0, length lst - 1) lst
go i _ | i >= length lst = 0
go i len = if l > len then skipValue else max skipValue takeValue
where
(l, val) = values Arr.! i
skipValue = go (i+1) len
takeValue = val + go i (len - l)

一般来说,Haskell中的普通内存化可以像在其他语言中一样实现,方法是在缓存值的可变映射上关闭函数的内存化版本。如果您想要像纯函数一样方便地运行函数,则需要在IO中维护状态并使用unsafePerformIO

以下备忘录可能足以用于大多数代码提交网站,因为它只依赖于通常可用的System.IO.UnsafeData.IORefData.Map.Strict

import qualified Data.Map.Strict as Map
import System.IO.Unsafe
import Data.IORef
memo :: (Ord k) => (k -> v) -> (k -> v)
memo f = unsafePerformIO $ do
m <- newIORef Map.empty
return $ k -> unsafePerformIO $ do
mv <- Map.lookup k <$> readIORef m
case mv of
Just v -> return v
Nothing -> do
let v = f k
v `seq` modifyIORef' m $ Map.insert k v
return v

从你的问题和评论来看,你似乎是那种永远失望的人(!),所以unsafePerformIO的使用可能会让你失望,但如果GHC真的提供了一个记忆语用,这可能就是它在幕后所做的。

简单使用的示例:

fib :: Int -> Int
fib = memo fib'
where fib' 0 = 0
fib' 1 = 1
fib' n = fib (n-1) + fib (n-2)
main = do
print $ fib 100000

或更多(剧透?!),一个只存储长度为的maxValue版本

maxValue :: [(Int,Int)] -> Int -> Int
maxValue values = go
where go = memo (go' values)
go' [] len = 0
go' ((l, val): other) len =
if l > len then
skipValue
else
max skipValue takeValue
where skipValue = go' other len
takeValue = val + go (len - l)

由于takeValue案例重新评估了全套可销售的产品,但它足够快,可以通过链接网页上的所有测试案例,因此这做的工作比必要的要多一些。如果速度不够快,那么你需要一个记忆器来记忆一个函数,该函数的结果在参数不相同的调用中共享(长度相同,但有不同的可销售部分,因为问题的特殊方面以及检查不同可销售部分和长度的顺序,你知道答案无论如何都是一样的)。这将是一种非标准的记忆,但修改memo函数来处理这种情况并不困难,我不认为,简单地将参数拆分为";键";论点和";非密钥";参数,或者通过在记忆时提供的任意函数从参数派生密钥。

最新更新