如何在python中加速trampolined cps版本的fib函数并支持相互递归?

How to speed up the trampolined cps version fib function and support mutual recursion in python?

我已经尝试为斐波那契函数的 cps 版本实现蹦床。但是我做不快(加缓存)支持mutual_recursion.

实现代码:

import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable

START = 0
CONTINUE = 1
CONTINUE_END = 2
RETURN = 3


@dataclass
class CTX:
    kind: int
    result: Any    # TODO ......
    f: Callable
    args: Optional[list]
    kwargs: Optional[dict]


def trampoline(f):
    ctx = CTX(START, None, None, None, None)

    @functools.wraps(f)
    def decorator(*args, **kwargs):
        nonlocal ctx
        if ctx.kind in (CONTINUE, CONTINUE_END):
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE
            return
        elif ctx.kind == START:
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE

        result = None
        while ctx.kind != RETURN:
            args = ctx.args
            kwargs = ctx.kwargs
            result = f(*args, **kwargs)
            if ctx.kind == CONTINUE_END:
                ctx.kind = RETURN
            else:
                ctx.kind = CONTINUE_END

        return result

    return decorator

这是运行可行的例子。

@functools.lru_cache
def fib(n):
    if n == 0:
        return 1
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)

@trampoline
def fib_cps(n, k):
    if n == 0:
        return k(1)
    elif n == 1:
        return k(1)
    else:
        return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))

def fib_cps_wrapper(n):
    return fib_cps(n, lambda i:i)


@trampoline
def fib_tail(n, acc1=1, acc2=1):
    if n < 2:
        return acc1
    else:
        return fib_tail(n - 1, acc1 + acc2, acc1)


if __name__ == "__main__":
    print(fib(100))
    print(fib_tail(10000))
    print(fib_cps_wrapper(40))

运行 号码 40 太慢了。 当 n 更大时,fib 得到 最大递归深度超过 。但是加上lru_cache之后就快了。 iter 蹦床版本适用于递归深度并且运行非常快。

这是其他人的作品:

  1. 支持cps版本缓存:https://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
  2. 支持mutual_recursion:https://github.com/0x65/trampoline但是太坑了,看不懂。

查看您分享的链接,有很多有趣的解决方案。我特别受到 this and changed a few things. Just a recap, you need a tail-recursive decorator that both caches results from previous executions of the function and supports mutual recursion (?). There is another interesting discussion 关于 tail-recursion 上下文中的相互递归的启发,这可能有助于您理解主要问题。


我写了一个装饰器,它既可以缓存又可以mutual-recursion:我认为它可以更进一步simplified/improved,但它适用于我选择的测试样本:

from collections import namedtuple
import functools

TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
    f._first_call = True
    f._cache = {}

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if f._first_call:
            f._new_args = args
            f._new_kwargs = kwargs
        
            try:
                f._first_call = False
                while True:
                    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
                    if cache_key in f._cache:
                        return f._cache[cache_key]

                    result = f(*f._new_args, **f._new_kwargs)

                    if not isinstance(result, TailRecArguments):
                        f._cache[cache_key] = result

                    if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                        f._new_args = result.args
                        f._new_kwargs = result.kwargs
                    else:
                        break

                return result
            finally:
                f._first_call = True
        else:
            return TailRecArguments(f, args, kwargs)

    return wrapper

乍一看似乎很复杂,但它重复使用了链接中讨论的一些概念。


初始化

f._first_call = True
f._cache = {}

而不是像 STARTCONTINUERETURN 这样的状态,在这种情况下,我只需要区分 _first_call 和后面的那些。事实上,第一次调用函数后,下一次调用return一个存储参数的TailRecArgument

f._cache 是该特定函数的缓存。


Tail-Recursion

if f._first_call:
    f._new_args = args
    f._new_kwargs = kwargs

    try:
        f._first_call = False
        while True:
            result = f(*f._new_args, **f._new_kwargs)

            if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                f._new_args = result.args
                f._new_kwargs = result.kwargs
            else:
                break

        return result
    finally:
        f._first_call = True
else:
    return TailRecArguments(f, args, kwargs)

这个版本的 tail-recursion 是如何工作的?在 while 循环中,在第一次调用装饰函数后 returned 将使用新参数连续调用该函数。

我什么时候可以退出循环?一旦 returned 值不是 TailRecArguments 类型,这意味着最后一个函数调用没有递归调用自身而是 returned 一个实际值。在那种情况下,我只需要 return 结果并设置 f._first_call = True。不幸的是,它比这复杂一点,因为它不能与相互递归一起工作。这里的修复是将调用的函数存储在 TailRecArguments 中。通过这种方式,我可以检查用于下一个循环的参数是用于同一函数 (result.wrapped_func == f) 还是用于另一个 tail-recursive 函数。在后一种情况下,我不想处理这些参数,因为它们与另一个函数相关,相反我可以 return 它们,因为它们肯定会在第一个 [=137= 的 while 循环中执行] 遇到的功能。唯一 缺点 是每次参数属于另一个函数时都会重置 f._first_call


缓存

while True:
    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
    if cache_key in f._cache:
        return f._cache[cache_key]

    result = f(*f._new_args, **f._new_kwargs)

    if not isinstance(result, TailRecArguments):
        f._cache[cache_key] = result

在评论缓存机制(这是非常流行的记忆技术)之前,正确放置缓存代码很重要:注意我把它放在 while 循环中。不可能,因为只有在 while 循环内,函数才会被连续调用,我可以检查缓存命中。

我在 cache_key 的创建过程中作了一些欺骗,因为我使用了 functools 模块的内部函数。它是同一模块中 @cache 装饰器使用的那个,您可以使用

提取代码
import inspect
import functools
print(inspect.getsource(functools._make_key))

还有其他方法可以从 *args**kwargs 创建缓存键,例如 this one,它再次指向 _make_key 的实现。为了让你的代码更稳定,当然要避免使用私有成员。

正如我所说,剩下的就是记忆,还有一个额外的检查:if not isinstance(result, TailRecArguments): ...。我想缓存值,而不是 tail-recursive 调用的参数。

(实际上,我认为您可以暂时将所有 TailRecArguments 存储在一个列表中,并在缓存中添加与该列表大小一样多的条目,当实际值被 return 编辑时一个递归调用。它会使解决方案复杂化,但如果您有性能问题,仍然可以接受。这可能会在相互递归的情况下引发一些错误,如果需要,我将进行处理)。


测试

这些是我用来测试装饰器的几个基本功能:

@tail_recursive
def even(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> even(100)
    True
    >>> even(101)
    False
    """
    return True if n == 0 else odd(n - 1)

@tail_recursive
def odd(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> odd(100)
    False
    >>> odd(101)
    True
    """
    return False if n == 0 else even(n - 1)

@tail_recursive
def fact(n, acc=1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> fact(30)
    265252859812191058636308480000000
    """
    return acc if n <= 1 else fact(n - 1, acc * n)

@tail_recursive
def fib(n, a = 0, b = 1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(20)
    >>> fib(30)
    832040
    """
    return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)

if __name__ == '__main__':
    import doctest
    doctest.testmod()

请注意,缓存在这些示例中不是很有用,以阶乘为例:fact(10) 永远不会使用 fact(8),事实上

fact(8) fact(10)
fact(10, 1)
fact(9, 10)
fact(8, 1) fact(8, 90)
... ...

累加器是缓存键的一部分,因此您应该通过自定义要缓存的参数来更改缓存策略(同样,如果需要,我也可以为此提出解决方案)。


更新 - 缓存优化

这是对原始答案中使用的缓存策略的部分修复。主要问题是,考虑到通用 tail-recursive 算法的工作原理(参见阶乘示例),将所有参数包含在缓存键中效率低下。

第一个可能的优化是让用户选择哪些参数用于键,哪些参数用于值。由于类型提示,它的可读性要差得多,但测试让一切变得更加清晰:

class Logger:
    def __init__(self, name):
        self._name = name
        self._entries = []
    
    def log(self, s):
        self._entries.append(s)

    def print(self):
        log_prefix = f"[{self._name}] - "
        print(log_prefix + f"\n{log_prefix}".join(self._entries))

TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
default_logger = Logger('default')
def tail_recursive(logger: Logger = default_logger, \
        get_cache_key: Callable[[Iterable, Dict], Hashable] = lambda args, kwargs: \
            functools._make_key(args, kwargs, False),\
        get_result_after_cache_hit: Callable[[Any, Iterable, Dict], Any] = lambda value, args, kwargs: \
            value):
    def decorator(f):
        f._first_call = True
        f._cache = {}

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if f._first_call:
                f._new_args = args
                f._new_kwargs = kwargs
            
                try:
                    f._first_call = False
                    f._initial_key = get_cache_key(f._new_args, f._new_kwargs)
                    while True:
                        cache_key = get_cache_key(f._new_args, f._new_kwargs)
                        if cache_key in f._cache:
                            logger.log('cache hit for ' + str(cache_key))
                            return get_result_after_cache_hit(f._cache[cache_key], f._new_args, f._new_kwargs)

                        result = f(*f._new_args, **f._new_kwargs)

                        if not isinstance(result, TailRecArguments):
                            f._cache[f._initial_key] = result

                        if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                            f._new_args = result.args
                            f._new_kwargs = result.kwargs
                        else:
                            break

                    return result
                finally:
                    f._first_call = True
            else:
                return TailRecArguments(f, args, kwargs)

        return wrapper
    return decorator

除了仅用于确认缓存命中的 Logger class 之外,主要区别在于每个函数现在都有一个名为 _initial_key 的新成员,用于存储密钥的第一个电话。这样,如果我调用 fact(5)5 就变成 _initial_key,结果放在 f._cache[5].

这样可以优化相互递归和tail-recursive函数,但在某些情况下效果不佳。让我们从最好的情况开始:

fact_logger = Logger('fact')
@tail_recursive(logger=fact_logger, get_cache_key=lambda args, kwargs: args[0],\
    get_result_after_cache_hit=lambda value, args, kwargs: value * args[1])
def fact(n, acc=1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> fact(5)
    120
    >>> fact(30)
    265252859812191058636308480000000
    >>> fact_logger.print()
    [fact] - cache hit for 5
    """
    return acc if n <= 1 else fact(n - 1, acc * n)

@tail_recursive 装饰器初始化包括(记录器)get_cache_key 指定只有第一个参数 n 应该是缓存键的一部分和 get_result_after_cache_hit指定如何在之后产生最终结果缓存命中。在上面的例子中,当 fact(30) 达到 fact(5, <partial_factorial>) 时,结果立即计算为 <partial_factorial> * f._cache[5].

even-odd 也是如此,只是在这种情况下 tail_recursive 的默认参数绰绰有余:

even_logger = Logger('even')
@tail_recursive(logger=even_logger)
def even(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> even(100)
    True
    >>> even(101)
    False
    >>> even(104)
    True
    >>> even_logger.print()
    [even] - cache hit for 100
    """
    return True if n == 0 else odd(n - 1)

不幸的是,这不适用于斐波那契函数。您应该通过在每次调用期间打印参数来轻松说服自己,结果如下所示:

30 0 1
29 1 1
28 1 2
27 2 3
26 3 5
25 5 8
...

建立缓存键规则需要一个更复杂的逻辑,这可能会使 tail_recursive 装饰器变得非常不可读且可移植性较差。