Python: 为特定函数调用修补打印函数? (用于打印递归树的装饰器)
Python: Patch the print function for a particular function call? (Decorator for printing the recursion tree)
我写了一个装饰器来打印一些函数调用产生的递归树。
from functools import wraps
def printRecursionTree(func):
global _recursiondepth
_print = print
_recursiondepth = 0
def getpads():
if _recursiondepth == 0:
strFn = '{} └──'.format(' │ ' * (_recursiondepth-1))
strOther = '{} ▒▒'.format(' │ ' * (_recursiondepth-1))
strRet = '{} '.format(' │ ' * (_recursiondepth-1))
else:
strFn = ' {} ├──'.format(' │ ' * (_recursiondepth-1))
strOther = ' {} │▒▒'.format(' │ ' * (_recursiondepth-1))
strRet = ' {} │ '.format(' │ ' * (_recursiondepth-1))
return strFn, strRet, strOther
def indentedprint():
@wraps(print)
def wrapper(*args, **kwargs):
strFn, strRet, strOther = getpads()
_print(strOther, end=' ')
_print(*args, **kwargs)
return wrapper
@wraps(func)
def wrapper(*args, **kwargs):
global _recursiondepth
global print
strFn, strRet, strOther = getpads()
if args and kwargs:
_print(strFn, '{}({}, {}):'.format(func.__qualname__, ', '.join(args), kwargs))
else:
_print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(map(str, args)) if args else '', kwargs if kwargs else ''))
_recursiondepth += 1
print, backup = indentedprint(), print
retval = func(*args, **kwargs)
print = backup
_recursiondepth -= 1
_print(strRet, '╰', retval)
if _recursiondepth == 0:
_print()
return retval
return wrapper
用法示例:
@printRecursionTree
def fib(n):
if n <= 1:
print('Base Case')
return n
print('Recursive Case')
return fib(n-1) + fib(n-2)
# This works with mutually recursive functions too,
# since the variable _recursiondepth is global
@printRecursionTree
def iseven(n):
print('checking if even')
if n == 0: return True
return isodd(n-1)
@printRecursionTree
def isodd(n):
print('checking if odd')
if n == 0: return False
return iseven(n-1)
iseven(5)
fib(5)
'''Prints:
└── iseven(5):
│▒▒ checking if even
│▒▒ Note how the print
│▒▒ statements get nicely indented
├── isodd(4):
│ │▒▒ checking if odd
│ ├── iseven(3):
│ │ │▒▒ checking if even
│ │ │▒▒ Note how the print
│ │ │▒▒ statements get nicely indented
│ │ ├── isodd(2):
│ │ │ │▒▒ checking if odd
│ │ │ ├── iseven(1):
│ │ │ │ │▒▒ checking if even
│ │ │ │ │▒▒ Note how the print
│ │ │ │ │▒▒ statements get nicely indented
│ │ │ │ ├── isodd(0):
│ │ │ │ │ │▒▒ checking if odd
│ │ │ │ │ ╰ False
│ │ │ │ ╰ False
│ │ │ ╰ False
│ │ ╰ False
│ ╰ False
╰ False
└── fib(5):
│▒▒ Recursive Case
├── fib(4):
│ │▒▒ Recursive Case
│ ├── fib(3):
│ │ │▒▒ Recursive Case
│ │ ├── fib(2):
│ │ │ │▒▒ Recursive Case
│ │ │ ├── fib(1):
│ │ │ │ │▒▒ Base Case
│ │ │ │ ╰ 1
│ │ │ ├── fib(0):
│ │ │ │ │▒▒ Base Case
│ │ │ │ ╰ 0
│ │ │ ╰ 1
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ╰ 2
│ ├── fib(2):
│ │ │▒▒ Recursive Case
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ├── fib(0):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 0
│ │ ╰ 1
│ ╰ 3
├── fib(3):
│ │▒▒ Recursive Case
│ ├── fib(2):
│ │ │▒▒ Recursive Case
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ├── fib(0):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 0
│ │ ╰ 1
│ ├── fib(1):
│ │ │▒▒ Base Case
│ │ ╰ 1
│ ╰ 2
╰ 5
'''
这个示例代码可以正常工作只要它在定义装饰器的同一个文件中。
但是,如果从某个模块导入装饰器,打印语句将不再缩进。
我知道出现这种行为是因为装饰器修补的 print 语句对于它自己的模块是全局的,不跨模块共享。
- 我该如何解决这个问题?
- 是否有更好的方法来仅针对对另一个函数的特定调用修补一个函数?
您可以通过在 builtins
模块中替换它来更改所有模块的内置打印功能的行为。
因此,将您对全局变量 print
的赋值更改为对 builtins.print
的赋值(在导入 builtins
之后):
import builtins
...
@wraps(func)
def wrapper(*args, **kwargs):
global _recursiondepth # no more need for global print up here
strFn, strRet, strOther = getpads()
if args and kwargs:
_print(strFn, '{}({}, {}):'.format(func.__qualname__, ', '.join(args), kwargs))
else:
_print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(map(str, args)) if args else '', kwargs if kwargs else ''))
_recursiondepth += 1
builtins.print, backup = indentedprint(), print # change here
retval = func(*args, **kwargs)
builtins.print = backup # and here
_recursiondepth -= 1
_print(strRet, '╰', retval)
if _recursiondepth == 0:
_print()
return retval
我写了一个装饰器来打印一些函数调用产生的递归树。
from functools import wraps
def printRecursionTree(func):
global _recursiondepth
_print = print
_recursiondepth = 0
def getpads():
if _recursiondepth == 0:
strFn = '{} └──'.format(' │ ' * (_recursiondepth-1))
strOther = '{} ▒▒'.format(' │ ' * (_recursiondepth-1))
strRet = '{} '.format(' │ ' * (_recursiondepth-1))
else:
strFn = ' {} ├──'.format(' │ ' * (_recursiondepth-1))
strOther = ' {} │▒▒'.format(' │ ' * (_recursiondepth-1))
strRet = ' {} │ '.format(' │ ' * (_recursiondepth-1))
return strFn, strRet, strOther
def indentedprint():
@wraps(print)
def wrapper(*args, **kwargs):
strFn, strRet, strOther = getpads()
_print(strOther, end=' ')
_print(*args, **kwargs)
return wrapper
@wraps(func)
def wrapper(*args, **kwargs):
global _recursiondepth
global print
strFn, strRet, strOther = getpads()
if args and kwargs:
_print(strFn, '{}({}, {}):'.format(func.__qualname__, ', '.join(args), kwargs))
else:
_print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(map(str, args)) if args else '', kwargs if kwargs else ''))
_recursiondepth += 1
print, backup = indentedprint(), print
retval = func(*args, **kwargs)
print = backup
_recursiondepth -= 1
_print(strRet, '╰', retval)
if _recursiondepth == 0:
_print()
return retval
return wrapper
用法示例:
@printRecursionTree
def fib(n):
if n <= 1:
print('Base Case')
return n
print('Recursive Case')
return fib(n-1) + fib(n-2)
# This works with mutually recursive functions too,
# since the variable _recursiondepth is global
@printRecursionTree
def iseven(n):
print('checking if even')
if n == 0: return True
return isodd(n-1)
@printRecursionTree
def isodd(n):
print('checking if odd')
if n == 0: return False
return iseven(n-1)
iseven(5)
fib(5)
'''Prints:
└── iseven(5):
│▒▒ checking if even
│▒▒ Note how the print
│▒▒ statements get nicely indented
├── isodd(4):
│ │▒▒ checking if odd
│ ├── iseven(3):
│ │ │▒▒ checking if even
│ │ │▒▒ Note how the print
│ │ │▒▒ statements get nicely indented
│ │ ├── isodd(2):
│ │ │ │▒▒ checking if odd
│ │ │ ├── iseven(1):
│ │ │ │ │▒▒ checking if even
│ │ │ │ │▒▒ Note how the print
│ │ │ │ │▒▒ statements get nicely indented
│ │ │ │ ├── isodd(0):
│ │ │ │ │ │▒▒ checking if odd
│ │ │ │ │ ╰ False
│ │ │ │ ╰ False
│ │ │ ╰ False
│ │ ╰ False
│ ╰ False
╰ False
└── fib(5):
│▒▒ Recursive Case
├── fib(4):
│ │▒▒ Recursive Case
│ ├── fib(3):
│ │ │▒▒ Recursive Case
│ │ ├── fib(2):
│ │ │ │▒▒ Recursive Case
│ │ │ ├── fib(1):
│ │ │ │ │▒▒ Base Case
│ │ │ │ ╰ 1
│ │ │ ├── fib(0):
│ │ │ │ │▒▒ Base Case
│ │ │ │ ╰ 0
│ │ │ ╰ 1
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ╰ 2
│ ├── fib(2):
│ │ │▒▒ Recursive Case
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ├── fib(0):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 0
│ │ ╰ 1
│ ╰ 3
├── fib(3):
│ │▒▒ Recursive Case
│ ├── fib(2):
│ │ │▒▒ Recursive Case
│ │ ├── fib(1):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 1
│ │ ├── fib(0):
│ │ │ │▒▒ Base Case
│ │ │ ╰ 0
│ │ ╰ 1
│ ├── fib(1):
│ │ │▒▒ Base Case
│ │ ╰ 1
│ ╰ 2
╰ 5
'''
这个示例代码可以正常工作只要它在定义装饰器的同一个文件中。
但是,如果从某个模块导入装饰器,打印语句将不再缩进。
我知道出现这种行为是因为装饰器修补的 print 语句对于它自己的模块是全局的,不跨模块共享。
- 我该如何解决这个问题?
- 是否有更好的方法来仅针对对另一个函数的特定调用修补一个函数?
您可以通过在 builtins
模块中替换它来更改所有模块的内置打印功能的行为。
因此,将您对全局变量 print
的赋值更改为对 builtins.print
的赋值(在导入 builtins
之后):
import builtins
...
@wraps(func)
def wrapper(*args, **kwargs):
global _recursiondepth # no more need for global print up here
strFn, strRet, strOther = getpads()
if args and kwargs:
_print(strFn, '{}({}, {}):'.format(func.__qualname__, ', '.join(args), kwargs))
else:
_print(strFn, '{}({}):'.format(func.__qualname__, ', '.join(map(str, args)) if args else '', kwargs if kwargs else ''))
_recursiondepth += 1
builtins.print, backup = indentedprint(), print # change here
retval = func(*args, **kwargs)
builtins.print = backup # and here
_recursiondepth -= 1
_print(strRet, '╰', retval)
if _recursiondepth == 0:
_print()
return retval