如何在python中加速trampolined cps版本的fib函数并支持相互递归?
How to speed up the trampolined cps version fib function and support mutual recursion in python?
我已经尝试为斐波那契函数的 cps 版本实现蹦床。但是我做不快(加缓存)支持mutual_recursion.
实现代码:
import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable
START = 0
CONTINUE = 1
CONTINUE_END = 2
RETURN = 3
@dataclass
class CTX:
kind: int
result: Any # TODO ......
f: Callable
args: Optional[list]
kwargs: Optional[dict]
def trampoline(f):
ctx = CTX(START, None, None, None, None)
@functools.wraps(f)
def decorator(*args, **kwargs):
nonlocal ctx
if ctx.kind in (CONTINUE, CONTINUE_END):
ctx.args = args
ctx.kwargs = kwargs
ctx.kind = CONTINUE
return
elif ctx.kind == START:
ctx.args = args
ctx.kwargs = kwargs
ctx.kind = CONTINUE
result = None
while ctx.kind != RETURN:
args = ctx.args
kwargs = ctx.kwargs
result = f(*args, **kwargs)
if ctx.kind == CONTINUE_END:
ctx.kind = RETURN
else:
ctx.kind = CONTINUE_END
return result
return decorator
这是运行可行的例子。
@functools.lru_cache
def fib(n):
if n == 0:
return 1
elif n == 1:
return 1
else:
return fib(n - 1) + fib(n - 2)
@trampoline
def fib_cps(n, k):
if n == 0:
return k(1)
elif n == 1:
return k(1)
else:
return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))
def fib_cps_wrapper(n):
return fib_cps(n, lambda i:i)
@trampoline
def fib_tail(n, acc1=1, acc2=1):
if n < 2:
return acc1
else:
return fib_tail(n - 1, acc1 + acc2, acc1)
if __name__ == "__main__":
print(fib(100))
print(fib_tail(10000))
print(fib_cps_wrapper(40))
运行 号码 40
太慢了。
当 n
更大时,fib
得到 最大递归深度超过 。但是加上lru_cache
之后就快了。 iter 蹦床版本适用于递归深度并且运行非常快。
这是其他人的作品:
- 支持cps版本缓存:https://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
- 支持mutual_recursion:https://github.com/0x65/trampoline但是太坑了,看不懂。
查看您分享的链接,有很多有趣的解决方案。我特别受到 this and changed a few things. Just a recap, you need a tail-recursive decorator that both caches results from previous executions of the function and supports mutual recursion (?). There is another interesting discussion 关于 tail-recursion 上下文中的相互递归的启发,这可能有助于您理解主要问题。
我写了一个装饰器,它既可以缓存又可以mutual-recursion:我认为它可以更进一步simplified/improved,但它适用于我选择的测试样本:
from collections import namedtuple
import functools
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
乍一看似乎很复杂,但它重复使用了链接中讨论的一些概念。
初始化
f._first_call = True
f._cache = {}
而不是像 START
、CONTINUE
和 RETURN
这样的状态,在这种情况下,我只需要区分 _first_call
和后面的那些。事实上,第一次调用函数后,下一次调用return一个存储参数的TailRecArgument
。
f._cache
是该特定函数的缓存。
Tail-Recursion
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
result = f(*f._new_args, **f._new_kwargs)
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
这个版本的 tail-recursion 是如何工作的?在 while
循环中,在第一次调用装饰函数后 returned 将使用新参数连续调用该函数。
我什么时候可以退出循环?一旦 returned 值不是 TailRecArguments
类型,这意味着最后一个函数调用没有递归调用自身而是 returned 一个实际值。在那种情况下,我只需要 return 结果并设置 f._first_call = True
。不幸的是,它比这复杂一点,因为它不能与相互递归一起工作。这里的修复是将调用的函数存储在 TailRecArguments
中。通过这种方式,我可以检查用于下一个循环的参数是用于同一函数 (result.wrapped_func == f
) 还是用于另一个 tail-recursive 函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数相关,相反我可以 return 它们,因为它们肯定会在第一个 [=137= 的 while
循环中执行] 遇到的功能。唯一 缺点 是每次参数属于另一个函数时都会重置 f._first_call
。
缓存
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
在评论缓存机制(这是非常流行的记忆技术)之前,正确放置缓存代码很重要:注意我把它放在 while
循环中。不可能,因为只有在 while 循环内,函数才会被连续调用,我可以检查缓存命中。
我在 cache_key
的创建过程中作了一些欺骗,因为我使用了 functools
模块的内部函数。它是同一模块中 @cache
装饰器使用的那个,您可以使用
提取代码
import inspect
import functools
print(inspect.getsource(functools._make_key))
还有其他方法可以从 *args
和 **kwargs
创建缓存键,例如 this one,它再次指向 _make_key
的实现。为了让你的代码更稳定,当然要避免使用私有成员。
正如我所说,剩下的就是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): ...
。我想缓存值,而不是 tail-recursive 调用的参数。
(实际上,我认为您可以暂时将所有 TailRecArguments
存储在一个列表中,并在缓存中添加与该列表大小一样多的条目,当实际值被 return 编辑时一个递归调用。它会使解决方案复杂化,但如果您有性能问题,仍然可以接受。这可能会在相互递归的情况下引发一些错误,如果需要,我将进行处理)。
测试
这些是我用来测试装饰器的几个基本功能:
@tail_recursive
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
"""
return True if n == 0 else odd(n - 1)
@tail_recursive
def odd(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> odd(100)
False
>>> odd(101)
True
"""
return False if n == 0 else even(n - 1)
@tail_recursive
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(30)
265252859812191058636308480000000
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
def fib(n, a = 0, b = 1):
"""
>>> import sys
>>> sys.setrecursionlimit(20)
>>> fib(30)
832040
"""
return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)
if __name__ == '__main__':
import doctest
doctest.testmod()
请注意,缓存在这些示例中不是很有用,以阶乘为例:fact(10)
永远不会使用 fact(8)
,事实上
fact(8)
fact(10)
fact(10, 1)
fact(9, 10)
fact(8, 1)
fact(8, 90)
...
...
累加器是缓存键的一部分,因此您应该通过自定义要缓存的参数来更改缓存策略(同样,如果需要,我也可以为此提出解决方案)。
更新 - 缓存优化
这是对原始答案中使用的缓存策略的部分修复。主要问题是,考虑到通用 tail-recursive 算法的工作原理(参见阶乘示例),将所有参数包含在缓存键中效率低下。
第一个可能的优化是让用户选择哪些参数用于键,哪些参数用于值。由于类型提示,它的可读性要差得多,但测试让一切变得更加清晰:
class Logger:
def __init__(self, name):
self._name = name
self._entries = []
def log(self, s):
self._entries.append(s)
def print(self):
log_prefix = f"[{self._name}] - "
print(log_prefix + f"\n{log_prefix}".join(self._entries))
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
default_logger = Logger('default')
def tail_recursive(logger: Logger = default_logger, \
get_cache_key: Callable[[Iterable, Dict], Hashable] = lambda args, kwargs: \
functools._make_key(args, kwargs, False),\
get_result_after_cache_hit: Callable[[Any, Iterable, Dict], Any] = lambda value, args, kwargs: \
value):
def decorator(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
f._initial_key = get_cache_key(f._new_args, f._new_kwargs)
while True:
cache_key = get_cache_key(f._new_args, f._new_kwargs)
if cache_key in f._cache:
logger.log('cache hit for ' + str(cache_key))
return get_result_after_cache_hit(f._cache[cache_key], f._new_args, f._new_kwargs)
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[f._initial_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
return decorator
除了仅用于确认缓存命中的 Logger
class 之外,主要区别在于每个函数现在都有一个名为 _initial_key
的新成员,用于存储密钥的第一个电话。这样,如果我调用 fact(5)
,5
就变成 _initial_key
,结果放在 f._cache[5]
.
这样可以优化相互递归和tail-recursive函数,但在某些情况下效果不佳。让我们从最好的情况开始:
fact_logger = Logger('fact')
@tail_recursive(logger=fact_logger, get_cache_key=lambda args, kwargs: args[0],\
get_result_after_cache_hit=lambda value, args, kwargs: value * args[1])
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(5)
120
>>> fact(30)
265252859812191058636308480000000
>>> fact_logger.print()
[fact] - cache hit for 5
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
装饰器初始化包括(记录器)get_cache_key
指定只有第一个参数 n
应该是缓存键的一部分和 get_result_after_cache_hit
指定如何在之后产生最终结果缓存命中。在上面的例子中,当 fact(30)
达到 fact(5, <partial_factorial>)
时,结果立即计算为 <partial_factorial> * f._cache[5]
.
even-odd
也是如此,只是在这种情况下 tail_recursive
的默认参数绰绰有余:
even_logger = Logger('even')
@tail_recursive(logger=even_logger)
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
>>> even(104)
True
>>> even_logger.print()
[even] - cache hit for 100
"""
return True if n == 0 else odd(n - 1)
不幸的是,这不适用于斐波那契函数。您应该通过在每次调用期间打印参数来轻松说服自己,结果如下所示:
30 0 1
29 1 1
28 1 2
27 2 3
26 3 5
25 5 8
...
建立缓存键规则需要一个更复杂的逻辑,这可能会使 tail_recursive
装饰器变得非常不可读且可移植性较差。
我已经尝试为斐波那契函数的 cps 版本实现蹦床。但是我做不快(加缓存)支持mutual_recursion.
实现代码:
import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable
START = 0
CONTINUE = 1
CONTINUE_END = 2
RETURN = 3
@dataclass
class CTX:
kind: int
result: Any # TODO ......
f: Callable
args: Optional[list]
kwargs: Optional[dict]
def trampoline(f):
ctx = CTX(START, None, None, None, None)
@functools.wraps(f)
def decorator(*args, **kwargs):
nonlocal ctx
if ctx.kind in (CONTINUE, CONTINUE_END):
ctx.args = args
ctx.kwargs = kwargs
ctx.kind = CONTINUE
return
elif ctx.kind == START:
ctx.args = args
ctx.kwargs = kwargs
ctx.kind = CONTINUE
result = None
while ctx.kind != RETURN:
args = ctx.args
kwargs = ctx.kwargs
result = f(*args, **kwargs)
if ctx.kind == CONTINUE_END:
ctx.kind = RETURN
else:
ctx.kind = CONTINUE_END
return result
return decorator
这是运行可行的例子。
@functools.lru_cache
def fib(n):
if n == 0:
return 1
elif n == 1:
return 1
else:
return fib(n - 1) + fib(n - 2)
@trampoline
def fib_cps(n, k):
if n == 0:
return k(1)
elif n == 1:
return k(1)
else:
return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))
def fib_cps_wrapper(n):
return fib_cps(n, lambda i:i)
@trampoline
def fib_tail(n, acc1=1, acc2=1):
if n < 2:
return acc1
else:
return fib_tail(n - 1, acc1 + acc2, acc1)
if __name__ == "__main__":
print(fib(100))
print(fib_tail(10000))
print(fib_cps_wrapper(40))
运行 号码 40
太慢了。
当 n
更大时,fib
得到 最大递归深度超过 。但是加上lru_cache
之后就快了。 iter 蹦床版本适用于递归深度并且运行非常快。
这是其他人的作品:
- 支持cps版本缓存:https://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
- 支持mutual_recursion:https://github.com/0x65/trampoline但是太坑了,看不懂。
查看您分享的链接,有很多有趣的解决方案。我特别受到 this and changed a few things. Just a recap, you need a tail-recursive decorator that both caches results from previous executions of the function and supports mutual recursion (?). There is another interesting discussion 关于 tail-recursion 上下文中的相互递归的启发,这可能有助于您理解主要问题。
我写了一个装饰器,它既可以缓存又可以mutual-recursion:我认为它可以更进一步simplified/improved,但它适用于我选择的测试样本:
from collections import namedtuple
import functools
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
乍一看似乎很复杂,但它重复使用了链接中讨论的一些概念。
初始化
f._first_call = True
f._cache = {}
而不是像 START
、CONTINUE
和 RETURN
这样的状态,在这种情况下,我只需要区分 _first_call
和后面的那些。事实上,第一次调用函数后,下一次调用return一个存储参数的TailRecArgument
。
f._cache
是该特定函数的缓存。
Tail-Recursion
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
while True:
result = f(*f._new_args, **f._new_kwargs)
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
这个版本的 tail-recursion 是如何工作的?在 while
循环中,在第一次调用装饰函数后 returned 将使用新参数连续调用该函数。
我什么时候可以退出循环?一旦 returned 值不是 TailRecArguments
类型,这意味着最后一个函数调用没有递归调用自身而是 returned 一个实际值。在那种情况下,我只需要 return 结果并设置 f._first_call = True
。不幸的是,它比这复杂一点,因为它不能与相互递归一起工作。这里的修复是将调用的函数存储在 TailRecArguments
中。通过这种方式,我可以检查用于下一个循环的参数是用于同一函数 (result.wrapped_func == f
) 还是用于另一个 tail-recursive 函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数相关,相反我可以 return 它们,因为它们肯定会在第一个 [=137= 的 while
循环中执行] 遇到的功能。唯一 缺点 是每次参数属于另一个函数时都会重置 f._first_call
。
缓存
while True:
cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
if cache_key in f._cache:
return f._cache[cache_key]
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[cache_key] = result
在评论缓存机制(这是非常流行的记忆技术)之前,正确放置缓存代码很重要:注意我把它放在 while
循环中。不可能,因为只有在 while 循环内,函数才会被连续调用,我可以检查缓存命中。
我在 cache_key
的创建过程中作了一些欺骗,因为我使用了 functools
模块的内部函数。它是同一模块中 @cache
装饰器使用的那个,您可以使用
import inspect
import functools
print(inspect.getsource(functools._make_key))
还有其他方法可以从 *args
和 **kwargs
创建缓存键,例如 this one,它再次指向 _make_key
的实现。为了让你的代码更稳定,当然要避免使用私有成员。
正如我所说,剩下的就是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): ...
。我想缓存值,而不是 tail-recursive 调用的参数。
(实际上,我认为您可以暂时将所有 TailRecArguments
存储在一个列表中,并在缓存中添加与该列表大小一样多的条目,当实际值被 return 编辑时一个递归调用。它会使解决方案复杂化,但如果您有性能问题,仍然可以接受。这可能会在相互递归的情况下引发一些错误,如果需要,我将进行处理)。
测试
这些是我用来测试装饰器的几个基本功能:
@tail_recursive
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
"""
return True if n == 0 else odd(n - 1)
@tail_recursive
def odd(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> odd(100)
False
>>> odd(101)
True
"""
return False if n == 0 else even(n - 1)
@tail_recursive
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(30)
265252859812191058636308480000000
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
def fib(n, a = 0, b = 1):
"""
>>> import sys
>>> sys.setrecursionlimit(20)
>>> fib(30)
832040
"""
return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)
if __name__ == '__main__':
import doctest
doctest.testmod()
请注意,缓存在这些示例中不是很有用,以阶乘为例:fact(10)
永远不会使用 fact(8)
,事实上
fact(8) |
fact(10) |
---|---|
fact(10, 1) | |
fact(9, 10) | |
fact(8, 1) | fact(8, 90) |
... | ... |
累加器是缓存键的一部分,因此您应该通过自定义要缓存的参数来更改缓存策略(同样,如果需要,我也可以为此提出解决方案)。
更新 - 缓存优化
这是对原始答案中使用的缓存策略的部分修复。主要问题是,考虑到通用 tail-recursive 算法的工作原理(参见阶乘示例),将所有参数包含在缓存键中效率低下。
第一个可能的优化是让用户选择哪些参数用于键,哪些参数用于值。由于类型提示,它的可读性要差得多,但测试让一切变得更加清晰:
class Logger:
def __init__(self, name):
self._name = name
self._entries = []
def log(self, s):
self._entries.append(s)
def print(self):
log_prefix = f"[{self._name}] - "
print(log_prefix + f"\n{log_prefix}".join(self._entries))
TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
default_logger = Logger('default')
def tail_recursive(logger: Logger = default_logger, \
get_cache_key: Callable[[Iterable, Dict], Hashable] = lambda args, kwargs: \
functools._make_key(args, kwargs, False),\
get_result_after_cache_hit: Callable[[Any, Iterable, Dict], Any] = lambda value, args, kwargs: \
value):
def decorator(f):
f._first_call = True
f._cache = {}
@functools.wraps(f)
def wrapper(*args, **kwargs):
if f._first_call:
f._new_args = args
f._new_kwargs = kwargs
try:
f._first_call = False
f._initial_key = get_cache_key(f._new_args, f._new_kwargs)
while True:
cache_key = get_cache_key(f._new_args, f._new_kwargs)
if cache_key in f._cache:
logger.log('cache hit for ' + str(cache_key))
return get_result_after_cache_hit(f._cache[cache_key], f._new_args, f._new_kwargs)
result = f(*f._new_args, **f._new_kwargs)
if not isinstance(result, TailRecArguments):
f._cache[f._initial_key] = result
if isinstance(result, TailRecArguments) and result.wrapped_func == f:
f._new_args = result.args
f._new_kwargs = result.kwargs
else:
break
return result
finally:
f._first_call = True
else:
return TailRecArguments(f, args, kwargs)
return wrapper
return decorator
除了仅用于确认缓存命中的 Logger
class 之外,主要区别在于每个函数现在都有一个名为 _initial_key
的新成员,用于存储密钥的第一个电话。这样,如果我调用 fact(5)
,5
就变成 _initial_key
,结果放在 f._cache[5]
.
这样可以优化相互递归和tail-recursive函数,但在某些情况下效果不佳。让我们从最好的情况开始:
fact_logger = Logger('fact')
@tail_recursive(logger=fact_logger, get_cache_key=lambda args, kwargs: args[0],\
get_result_after_cache_hit=lambda value, args, kwargs: value * args[1])
def fact(n, acc=1):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> fact(5)
120
>>> fact(30)
265252859812191058636308480000000
>>> fact_logger.print()
[fact] - cache hit for 5
"""
return acc if n <= 1 else fact(n - 1, acc * n)
@tail_recursive
装饰器初始化包括(记录器)get_cache_key
指定只有第一个参数 n
应该是缓存键的一部分和 get_result_after_cache_hit
指定如何在之后产生最终结果缓存命中。在上面的例子中,当 fact(30)
达到 fact(5, <partial_factorial>)
时,结果立即计算为 <partial_factorial> * f._cache[5]
.
even-odd
也是如此,只是在这种情况下 tail_recursive
的默认参数绰绰有余:
even_logger = Logger('even')
@tail_recursive(logger=even_logger)
def even(n):
"""
>>> import sys
>>> sys.setrecursionlimit(30)
>>> even(100)
True
>>> even(101)
False
>>> even(104)
True
>>> even_logger.print()
[even] - cache hit for 100
"""
return True if n == 0 else odd(n - 1)
不幸的是,这不适用于斐波那契函数。您应该通过在每次调用期间打印参数来轻松说服自己,结果如下所示:
30 0 1
29 1 1
28 1 2
27 2 3
26 3 5
25 5 8
...
建立缓存键规则需要一个更复杂的逻辑,这可能会使 tail_recursive
装饰器变得非常不可读且可移植性较差。