Python: 为特定函数调用修补打印函数? (用于打印递归树的装饰器)

Python: Patch the print function for a particular function call? (Decorator for printing the recursion tree)

我写了一个装饰器来打印一些函数调用产生的递归树。

from functools import wraps

def printRecursionTree(func):
    """Decorate ``func`` so that its calls are printed as a recursion tree.

    Every call prints a header line with the function's qualified name and
    its arguments, runs ``func`` while this module's global ``print`` is
    replaced by an indenting version, and finally prints the return value
    below the header.  A blank line is printed once the outermost call has
    returned.

    The nesting depth is kept in the module-global ``_recursiondepth``, so
    all decorated functions (for example mutually recursive ones) share a
    single tree.  Applying the decorator resets that counter to 0.

    Only the ``print`` global of the module that defines this decorator is
    rebound; ``print`` calls made from functions defined in other modules
    are therefore not indented.

    If ``func`` raises, the exception propagates unchanged, but the patched
    ``print`` and the depth counter are restored first.
    """
    global _recursiondepth
    _print = print  # the unpatched print, used for all tree drawing
    _recursiondepth = 0

    def getpads():
        """Return ``(strFn, strRet, strOther)`` prefixes for the current depth."""
        if _recursiondepth == 0:
            # Root level: no vertical guide lines are drawn yet.
            return ' └──', '    ', '  ▒▒'
        guides = ' │  ' * (_recursiondepth - 1)
        strFn = '    {} ├──'.format(guides)
        strOther = '    {} │▒▒'.format(guides)
        strRet = '    {} │  '.format(guides)
        return strFn, strRet, strOther

    def indentedprint():
        """Build a ``print`` replacement that prefixes output with the tree pad."""
        @wraps(print)
        def wrapper(*args, **kwargs):
            # The pad is computed per call so it always reflects the live depth.
            _, _, strOther = getpads()
            _print(strOther, end=' ')
            _print(*args, **kwargs)
        return wrapper

    @wraps(func)
    def wrapper(*args, **kwargs):
        global _recursiondepth
        global print

        # Pads are captured before the depth changes so the return line
        # lines up with this call's header.
        strFn, strRet, _ = getpads()

        # str() every positional argument (a bare join fails on non-strings)
        # and append the kwargs dict whenever there is one.
        shown = [str(arg) for arg in args]
        if kwargs:
            shown.append(str(kwargs))
        _print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(shown)))

        _recursiondepth += 1
        print, backup = indentedprint(), print
        try:
            retval = func(*args, **kwargs)
        finally:
            # Undo the patch even when func raises, otherwise every later
            # tree would start at the wrong depth with print still replaced.
            print = backup
            _recursiondepth -= 1
        _print(strRet, '╰', retval)
        if _recursiondepth == 0:
            _print()  # blank line separates consecutive trees
        return retval

    return wrapper

用法示例:

@printRecursionTree
def fib(n):
    """Return the n-th Fibonacci number, printing which case each call hits."""
    is_base = n <= 1
    print('Base Case' if is_base else 'Recursive Case')
    return n if is_base else fib(n-1) + fib(n-2)

# This works with mutually recursive functions too,
# since the variable _recursiondepth is global
@printRecursionTree
def iseven(n):
    """Return True when n is even, delegating to isodd for n - 1 otherwise."""
    print('checking if even')
    return True if n == 0 else isodd(n-1)

@printRecursionTree
def isodd(n):
    """Return True when n is odd, delegating to iseven for n - 1 otherwise."""
    print('checking if odd')
    return False if n == 0 else iseven(n-1)

# Demo: each top-level call prints one complete tree followed by a blank line.
iseven(5)  # mutual recursion: iseven and isodd nodes alternate in one tree
fib(5)     # branching recursion: every non-base call has two child nodes

'''Prints:

└── iseven(5):
     │▒▒ checking if even
     │▒▒ Note how the print
     │▒▒ statements get nicely indented
     ├── isodd(4):
     │   │▒▒ checking if odd
     │   ├── iseven(3):
     │   │   │▒▒ checking if even
     │   │   │▒▒ Note how the print
     │   │   │▒▒ statements get nicely indented
     │   │   ├── isodd(2):
     │   │   │   │▒▒ checking if odd
     │   │   │   ├── iseven(1):
     │   │   │   │   │▒▒ checking if even
     │   │   │   │   │▒▒ Note how the print
     │   │   │   │   │▒▒ statements get nicely indented
     │   │   │   │   ├── isodd(0):
     │   │   │   │   │   │▒▒ checking if odd
     │   │   │   │   │   ╰ False
     │   │   │   │   ╰ False
     │   │   │   ╰ False
     │   │   ╰ False
     │   ╰ False
     ╰ False

 └── fib(5):
     │▒▒ Recursive Case
     ├── fib(4):
     │   │▒▒ Recursive Case
     │   ├── fib(3):
     │   │   │▒▒ Recursive Case
     │   │   ├── fib(2):
     │   │   │   │▒▒ Recursive Case
     │   │   │   ├── fib(1):
     │   │   │   │   │▒▒ Base Case
     │   │   │   │   ╰ 1
     │   │   │   ├── fib(0):
     │   │   │   │   │▒▒ Base Case
     │   │   │   │   ╰ 0
     │   │   │   ╰ 1
     │   │   ├── fib(1):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 1
     │   │   ╰ 2
     │   ├── fib(2):
     │   │   │▒▒ Recursive Case
     │   │   ├── fib(1):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 1
     │   │   ├── fib(0):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 0
     │   │   ╰ 1
     │   ╰ 3
     ├── fib(3):
     │   │▒▒ Recursive Case
     │   ├── fib(2):
     │   │   │▒▒ Recursive Case
     │   │   ├── fib(1):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 1
     │   │   ├── fib(0):
     │   │   │   │▒▒ Base Case
     │   │   │   ╰ 0
     │   │   ╰ 1
     │   ├── fib(1):
     │   │   │▒▒ Base Case
     │   │   ╰ 1
     │   ╰ 2
     ╰ 5
'''

只要这段示例代码与装饰器的定义位于同一个文件中,它就可以正常工作。

但是,如果从某个模块导入装饰器,打印语句将不再缩进。

我知道出现这种行为的原因:装饰器修补的 print 只是装饰器所在模块的全局变量,不会跨模块共享。

  1. 我该如何解决这个问题?
  2. 是否有更好的方法,只在某个函数的特定调用期间修补另一个函数?

您可以通过替换 builtins 模块中的 print,来更改所有模块中内置 print 函数的行为。

因此,将您对全局变量 print 的赋值更改为对 builtins.print 的赋值(在导入 builtins 之后):

import builtins

...

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Prints one tree node for this call: a header line, the body's own
        # output (indented through a temporarily replaced builtins.print),
        # and the return value.
        global _recursiondepth # no more need for global print up here

        # Pads are captured before the depth changes so the return line
        # lines up with this call's header.
        strFn, strRet, _ = getpads()

        # str() every positional argument (a bare join fails on non-strings)
        # and append the kwargs dict whenever there is one.
        shown = [str(arg) for arg in args]
        if kwargs:
            shown.append(str(kwargs))
        _print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(shown)))

        _recursiondepth += 1
        # Rebinding builtins.print affects every module that has not
        # shadowed print, not only the one defining this decorator.
        backup = builtins.print
        builtins.print = indentedprint()
        try:
            retval = func(*args, **kwargs)
        finally:
            # Always undo the process-wide patch, even when func raises.
            builtins.print = backup
            _recursiondepth -= 1
        _print(strRet, '╰', retval)
        if _recursiondepth == 0:
            _print()  # blank line separates consecutive trees
        return retval