我尝试为fibonacci函数的cps版本实现蹦床。但是我不能使它快速(添加缓存)并支持mutual_recurrence。
机具代码:
import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable
START = 0
CONTINUE = 1
CONTINUE_END = 2
RETURN = 3
@dataclass
class CTX:
kind: int
result: Any # TODO ......
f: Callable
args: Optional[list]
kwargs: Optional[dict]
def trampoline(f):
ctx = CTX(START, None, None, None, None)
@functools.wraps(f)
def decorator(*args, **kwargs):
nonlocal ctx
if ctx.kind in (CONTINUE, CONTINUE_END):
ctx.args = args
ctx.kwargs = kwargs
ctx.kind = CONTINUE
return
elif ctx.kind == START:
ctx.args = args
ctx.kwargs = kwargs
ctx.kind = CONTINUE
result = None
while ctx.kind != RETURN:
args = ctx.args
kwargs = ctx.kwargs
result = f(*args, **kwargs)
if ctx.kind == CONTINUE_END:
ctx.kind = RETURN
else:
ctx.kind = CONTINUE_END
return result
return decorator
以下是可运行的示例。
@functools.lru_cache
def fib(n):
if n == 0:
return 1
elif n == 1:
return 1
else:
return fib(n - 1) + fib(n - 2)
@trampoline
def fib_cps(n, k):
if n == 0:
return k(1)
elif n == 1:
return k(1)
else:
return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))
def fib_cps_wrapper(n):
return fib_cps(n, lambda i:i)
@trampoline
def fib_tail(n, acc1=1, acc2=1):
if n < 2:
return acc1
else:
return fib_tail(n - 1, acc1 + acc2, acc1)
if __name__ == "__main__":
print(fib(100))
print(fib_tail(10000))
print(fib_cps_wrapper(40))
运行数字40
太慢。当n
较大时,fib
得到的最大递归深度超过了。但添加lru_cache
后会很快。iter蹦床版本可以进行递归深度,运行速度非常快。
以下是其他人的工作:
- 支持cps版本缓存:https://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
- 支持互斥递归:https://github.com/0x65/trampoline但这太难理解了
看看您共享的链接,有很多有趣的解决方案。我特别受此启发,并改变了一些事情。简单回顾一下,您需要一个尾部递归装饰器,它既可以缓存函数先前执行的结果,又支持相互递归(?)。关于尾部递归上下文中的相互递归,还有另一个有趣的讨论,这可能有助于您理解主要问题。
我已经编写了一个装饰器,它同时进行缓存和相互递归:我认为它可以进一步简化/改进,但它适用于我选择的测试样本:
from collections import namedtuple
import functools
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
乍一看,它似乎相当复杂,但它重复使用了链接中讨论的一些概念。
初始化
f._first_call = True
f._cache = {}
在这种情况下,我只需要区分_first_call
和以下状态,而不是像START
、CONTINUE
和RETURN
这样的状态。事实上,在第一次调用函数之后,接下来的调用将返回一个存储参数的TailRecArgument
。
f._cache
是用于该特定功能的高速缓存。
尾部递归
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
result = f(*f._new_args, **f._new_kwargs)
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
这个版本的尾部递归是如何工作的?在while
循环中,函数被连续调用,在第一次调用修饰函数后返回新的参数。
我什么时候可以退出循环?一旦返回的值不是TailRecArguments
类型,这意味着上一个函数调用没有递归调用自己,而是返回了一个实际值。在这种情况下,我只需要返回结果并设置f._first_call = True
。不幸的是,它比这复杂一点,因为它不适用于相互递归。这里的修复方法是将调用的函数存储在TailRecArguments
中。通过这种方式,我可以检查用于下一个循环的参数是针对同一个函数(result.wrapped_func == f
)还是针对另一个尾部递归函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数有关,相反,我可以返回它们,因为它们肯定会在遇到的第一个尾部递归函数的while
循环中执行。唯一的缺点是每次参数属于另一个函数时,f._first_call
都会重置。
缓存
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
在评论缓存机制(这是一种非常流行的内存化技术)之前,正确放置缓存代码是很重要的:请注意,我把它放在了while
循环中。否则就不可能了,因为只有在while循环中,函数才会被连续调用,我可以检查缓存命中率。
我在创建cache_key
时有点作弊,因为我使用了functools
模块的内部函数。它是@cache
装饰器在同一模块中使用的,您可以使用提取代码
import inspect
import functools
print(inspect.getsource(functools._make_key))
还有其他方法可以从*args
和**kwargs
创建缓存密钥,比如这一种,它再次指向_make_key
的实现。为了使代码更加稳定,当然要避免使用私有成员。
正如我所说,剩下的是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): ...
。我想要缓存值,而不是尾部递归调用的参数。
(事实上,我认为当递归调用返回实际值时,你可以将所有TailRecArguments
临时存储在一个列表中,并在缓存中添加与该列表大小相同的条目。这会使解决方案复杂化,但如果你有性能问题,仍然可以接受。这可能会在相互递归的情况下引发一些错误,如果需要,我会处理它)。
测试
以下是我用来测试装饰器的几个基本函数:
@tail_recursive
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
"""
return True if n == 0 else odd(n - 1)
@tail_recursive
def odd(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> odd(100)
False
>>> odd(101)
True
"""
return False if n == 0 else even(n - 1)
@tail_recursive
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(30)
265252859812191058636308480000000
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
def fib(n, a = 0, b = 1):
"""
>>> import sys
>>> sys.setrecursionlimit(20)
>>> fib(30)
832040
"""
return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)
if __name__ == '__main__':
import doctest
doctest.testmod()
请注意,缓存在这些示例中不是很有用,以factorial为例:fact(10)
永远不会使用fact(8)
,事实上是
fact(8) | fact(10) |
---|---|
事实(10,1) | |
事实(9,10) | |
fact(8,1) | fact(8,90) |