如何加快被践踏的cps版本fib函数的速度,并支持python中的相互递归



我尝试为fibonacci函数的cps版本实现蹦床。但是我不能使它快速(添加缓存)并支持mutual_recurrence。

机具代码:

import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable
START = 0
CONTINUE = 1
CONTINUE_END = 2
RETURN = 3

@dataclass
class CTX:
kind: int
result: Any    # TODO ......
f: Callable
args: Optional[list]
kwargs: Optional[dict]

def trampoline(f):
ctx = CTX(START, None, None, None, None)
@functools.wraps(f)
def decorator(*args, **kwargs):
nonlocal ctx
if ctx.kind in (CONTINUE, CONTINUE_END):
ctx.args = args
ctx.kwargs = kwargs
ctx.kind = CONTINUE
return
elif ctx.kind == START:
ctx.args = args
ctx.kwargs = kwargs
ctx.kind = CONTINUE
result = None
while ctx.kind != RETURN:
args = ctx.args
kwargs = ctx.kwargs
result = f(*args, **kwargs)
if ctx.kind == CONTINUE_END:
ctx.kind = RETURN
else:
ctx.kind = CONTINUE_END
return result
return decorator

以下是可运行的示例。

@functools.lru_cache
def fib(n):
if n == 0:
return 1
elif n == 1:
return 1
else:
return fib(n - 1) + fib(n - 2)
@trampoline
def fib_cps(n, k):
if n == 0:
return k(1)
elif n == 1:
return k(1)
else:
return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))
def fib_cps_wrapper(n):
return fib_cps(n, lambda i:i)

@trampoline
def fib_tail(n, acc1=1, acc2=1):
if n < 2:
return acc1
else:
return fib_tail(n - 1, acc1 + acc2, acc1)

if __name__ == "__main__":
print(fib(100))
print(fib_tail(10000))
print(fib_cps_wrapper(40))

运行数字40太慢。当n较大时,fib得到的最大递归深度超过了。但添加lru_cache后会很快。iter蹦床版本可以进行递归深度,运行速度非常快。

以下是其他人的工作:

  1. 支持cps版本缓存:https://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
  2. 支持互斥递归:https://github.com/0x65/trampoline但这太难理解了

看看您共享的链接,有很多有趣的解决方案。我特别受此启发,并改变了一些事情。简单回顾一下,您需要一个尾部递归装饰器,它既可以缓存函数先前执行的结果,又支持相互递归(?)。关于尾部递归上下文中的相互递归,还有另一个有趣的讨论,这可能有助于您理解主要问题。


我已经编写了一个装饰器,它同时进行缓存和相互递归:我认为它可以进一步简化/改进,但它适用于我选择的测试样本:

from collections import namedtuple
import functools
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs

try:
f._first_call = False
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper

乍一看,它似乎相当复杂,但它重复使用了链接中讨论的一些概念。


初始化

f._first_call = True
f._cache = {}

在这种情况下,我只需要区分_first_call和以下状态,而不是像STARTCONTINUERETURN这样的状态。事实上,在第一次调用函数之后,接下来的调用将返回一个存储参数的TailRecArgument

f._cache是用于该特定功能的高速缓存。


尾部递归

if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
result = f(*f._new_args, **f._new_kwargs)
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)

这个版本的尾部递归是如何工作的?在while循环中,函数被连续调用,在第一次调用修饰函数后返回新的参数。

我什么时候可以退出循环?一旦返回的值不是TailRecArguments类型,这意味着上一个函数调用没有递归调用自己,而是返回了一个实际值。在这种情况下,我只需要返回结果并设置f._first_call = True。不幸的是,它比这复杂一点,因为它不适用于相互递归。这里的修复方法是将调用的函数存储在TailRecArguments中。通过这种方式,我可以检查用于下一个循环的参数是针对同一个函数(result.wrapped_func == f)还是针对另一个尾部递归函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数有关,相反,我可以返回它们,因为它们肯定会在遇到的第一个尾部递归函数的while循环中执行。唯一的缺点是每次参数属于另一个函数时,f._first_call都会重置。


缓存

while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result

在评论缓存机制(这是一种非常流行的内存化技术)之前,正确放置缓存代码是很重要的:请注意,我把它放在了while循环中。否则就不可能了,因为只有在while循环中,函数才会被连续调用,我可以检查缓存命中率。

我在创建cache_key时有点作弊,因为我使用了functools模块的内部函数。它是@cache装饰器在同一模块中使用的,您可以使用提取代码

import inspect
import functools
print(inspect.getsource(functools._make_key))

还有其他方法可以从*args**kwargs创建缓存密钥,比如这一种,它再次指向_make_key的实现。为了使代码更加稳定,当然要避免使用私有成员。

正如我所说,剩下的是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): ...。我想要缓存值,而不是尾部递归调用的参数。

(事实上,我认为当递归调用返回实际值时,你可以将所有TailRecArguments临时存储在一个列表中,并在缓存中添加与该列表大小相同的条目。这会使解决方案复杂化,但如果你有性能问题,仍然可以接受。这可能会在相互递归的情况下引发一些错误,如果需要,我会处理它)。


测试

以下是我用来测试装饰器的几个基本函数:

@tail_recursive
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
"""
return True if n == 0 else odd(n - 1)
@tail_recursive
def odd(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> odd(100)
False
>>> odd(101)
True
"""
return False if n == 0 else even(n - 1)
@tail_recursive
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(30)
265252859812191058636308480000000
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
def fib(n, a = 0, b = 1):
"""
>>> import sys
>>> sys.setrecursionlimit(20)
>>> fib(30)
832040
"""
return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)
if __name__ == '__main__':
import doctest
doctest.testmod()

请注意,缓存在这些示例中不是很有用,以factorial为例:fact(10)永远不会使用fact(8),事实上是

fact(8)fact(10)
事实(10,1)
事实(9,10)
fact(8,1)fact(8,90)

相关内容

  • 没有找到相关文章

最新更新