Python 3.x:如何使用ast搜索打印语句

Python 3.x: how to use ast to search for a print statement

我正在创建一个测试,它应该检查函数是否包含 print 语句(Python 3.x,我使用的是 3.7.4)。我一直在使用 ast 来检查类似的事情(参考 问题中的答案),例如 return 或列表理解,但我陷入了 print.

一个 online AST explorer 在正文中列出了一个 Print 子类,它需要 Python 3 print 秒,所以我知道它不是 Python 2 件事。

Green Tree Snakes ast 文档说 Print 在 Python 2 中只有一个 ast 节点。这更接近我所经历的。这是我要用来做断言的函数:

def printsSomething(func):
    return any(isinstance(node, ast.Print) for node in ast.walk(ast.parse(inspect.getsource(func))))

returns:

TypeError: isinstance() arg 2 must be a type or tuple of types

我假设这与 print 是 Python 3.x 中的一个函数有关,但我不知道如何利用这些知识来发挥我的优势。我将如何使用 ast 来查明是否已调用 print

我想重申一下,我已经让这段代码适用于其他 ast 节点,例如 return,所以我应该确信这不是我的代码特有的错误。

谢谢!

print 是 python 3 中的一个函数,因此您需要检查一个 ast.Expr,它包含一个 ast.Call,而 ast.Name 具有编号 print.

这是一个简单的函数:

def bar(x: str) -> None:
    string = f"Hello {x}!"  # ast.Assign
    print(string)           # ast.Expr

这是完整的 ast 转储:

Module(body=[FunctionDef(name='bar', args=arguments(args=[arg(arg='x', annotation=Name(id='str', ctx=Load()))], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[]), body=[Assign(targets=[Name(id='string', ctx=Store())], value=JoinedStr(values=[Str(s='Hello '), FormattedValue(value=Name(id='x', ctx=Load()), conversion=-1, format_spec=None), Str(s='!')])), Expr(value=Call(func=Name(id='print', ctx=Load()), args=[Name(id='string', ctx=Load())], keywords=[]))], decorator_list=[], returns=NameConstant(value=None))])

相关部分(print)是:

Expr(value=Call(func=Name(id='print', ctx=Load())

下面是一个带有节点访问者的简单示例(sublcassing ast.NodeVisitor):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import ast
import inspect
from typing import Callable


class MyNodeVisitor(ast.NodeVisitor):
    def visit_Expr(self, node: ast.Expr):
        """Called when the visitor visits an ast.Expr"""
        print(f"Found expression node at: line: {node.lineno}; col: {node.col_offset}")

        # check "value" which must be an instance of "Call" for a 'print'
        if not isinstance(node.value, ast.Call):
            return

        # now check the function itself.
        func = node.value.func  # ast.Name
        if func.id == "print":
            print("found a print")


def contains_print(f: Callable):
    source = inspect.getsource(f)
    node = ast.parse(source)
    func_name = [_def.name for _def in node.body if isinstance(_def, ast.FunctionDef)][0]
    print(f"{'-' * 79}\nvisiting function: {func_name}")
    print(f"node dump: {ast.dump(node)}")
    node_visitor = MyNodeVisitor()
    node_visitor.visit(node)


def foo(x: int) -> int:
    return x + 1

def bar(x: str) -> None:
    string = f"Hello {x}!"  # ast.Assign
    print(string)           # ast.Expr

def baz(x: float) -> float:
    if x == 0.0:
        print("oh noes!")
        raise ValueError

    return 10 / x


if __name__ == "__main__":
    contains_print(bar)
    contains_print(foo)
    contains_print(baz)

这里是输出(减去 ast 转储):

-------------------------------------------------------------------------------
visiting function: bar
Found expression node at: line: 3; col: 4
found a print
-------------------------------------------------------------------------------
visiting function: foo
-------------------------------------------------------------------------------
visiting function: baz
Found expression node at: line: 3; col: 8
found a print

如果调用任何命名对象(包括 print 等函数),那么在您的 node 中将至少有一个 _ast.Name 对象。对象的名称 ('print') 存储在该节点的 id 属性下。

我相信您已经知道,print 在 python 版本 2 和版本 3 之间从语句更改为函数,这可能解释了您 运行 遇到问题的原因.

尝试以下操作:

import ast
import inspect

def do_print():
    print('hello')

def dont_print():
    pass

def prints_something(func):
    is_print = False
    for node in ast.walk(ast.parse(inspect.getsource(func))):
        try:
            is_print = (node.id == 'print')
        except AttributeError:  # only expect id to exist for Name objs
            pass
        if is_print:
            break
    return is_print

prints_something(do_print), prints_something(dont_print)

>>> True, False

...或者如果您喜欢单行代码(其中 func 是您要测试的函数):

any(hasattr(node,'id') and node.id == 'print' 
    for node in ast.walk(ast.parse(inspect.getsource(func))))