Inline replace of AST node by query
Given input like:
query = ("class_name", "function_name", "arg_name")
how do I replace whatever it finds with some other provided node?
Example input, before the rewrite:
class Foo(object):
    def f(self, g: str = "foo"): pass
Example output, after the rewrite:
class Foo(object):
    def f(self, g: int = 5): pass
Given the following hypothetical function call:
replace_ast_node(
query=("Foo", "f", "g"),
node=ast.parse("class Foo(object):\n def f(self, g: str = 'foo'): pass"),
# Use `AnnAssign` over `arg`; as `defaults` is higher in the `FunctionDef`
replace_with=AnnAssign(
annotation=Name(ctx=Load(), id="int"),
simple=1,
target=Name(ctx=Store(), id="g"),
value=Constant(kind=None, value=5),
),
)
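The comment above says the default is "higher" in the FunctionDef. The sketch below shows where the pieces of g actually land after parsing. The variable names are illustrative, and ast.unparse requires Python 3.9+. The parameter is an ast.arg that holds only its name and annotation. Its default value is stored one level up, on the ast.arguments node.

```python
import ast

source = 'class Foo(object):\n    def f(self, g: str = "foo"): pass\n'
module = ast.parse(source)

func = module.body[0].body[0]  # ast.FunctionDef for Foo.f
params = func.args             # ast.arguments: holds every parameter of f

# The parameter itself: an ast.arg carrying a name and an annotation only.
g = params.args[1]
print(type(g).__name__, g.arg, ast.unparse(g.annotation))

# Its default lives on ast.arguments, aligned with the *last*
# len(defaults) positional parameters (here: just g).
print([ast.unparse(d) for d in params.defaults])
```

Because of this split, replacing only the ast.arg node changes the name and annotation but leaves the default untouched.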
I've hacked together a simple solution for finding a node with the query list. It has the added benefit of working for anything: ("Foo", "f", "g") could just as well refer to def Foo(): def f(): def g():. I also have a parser/emitter from arg to AnnAssign. But I can't figure out this stage. Does ast.NodeTransformer traverse in order? If so, should I keep traversing, appending the current name, and checking whether the current location matches the full query? I feel like I'm missing some clean solution…
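On the traversal-order question: ast.NodeTransformer, like ast.NodeVisitor, recurses through generic_visit. That method visits each node's fields in their declared order, so the walk is depth-first and pre-order. ast.walk, by contrast, is documented as yielding nodes "in no specified order"; CPython happens to use a breadth-first queue. The sketch below records the order in which named nodes are reached. OrderRecorder and label are hypothetical helpers written for this illustration.

```python
import ast

source = (
    "class Foo:\n"
    "    def f(self, g): pass\n"
    "    def h(self): pass\n"
    "def bar(): pass\n"
)

def label(node):
    """Name of a class/function node, or the parameter name of an ast.arg."""
    if isinstance(node, ast.arg):
        return node.arg
    return getattr(node, "name", None)

class OrderRecorder(ast.NodeTransformer):
    """Hypothetical helper: records labels in the order generic_visit reaches them."""

    def __init__(self):
        self.seen = []

    def generic_visit(self, node):
        name = label(node)
        if name is not None:
            self.seen.append(name)
        return super().generic_visit(node)

tree = ast.parse(source)
recorder = OrderRecorder()
recorder.visit(tree)
# Depth-first: all of Foo (its methods and f's parameters) is reached before bar.
print(recorder.seen)

# ast.walk yields the same nodes, but its order is not specified by the docs.
walk_order = [label(n) for n in ast.walk(tree) if label(n) is not None]
print(walk_order)
```

So, with a NodeTransformer, maintaining a path of names as you descend is a workable strategy. Relying on ast.walk's order for the same thing is not.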
I decided to split it into two problems. First, make every node aware of its place in the universe:
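import ast

def annotate_ancestry(node):
    """
    Look to your roots. Find the child; find the parent.
    Sets the _location attribute on every named child node and every ast.arg.
    :param node: AST node. Will be annotated in-place.
    :type node: ```ast.AST```
    """
    node._location = [node.name] if hasattr(node, 'name') else []
    # Remember each node's enclosing path, so siblings (e.g., two methods)
    # and deeper nesting (e.g., Foo.f.g) each get their own full location.
    scope = {node: node._location}
    for _node in ast.walk(node):
        parent_location = scope[_node]
        for child in ast.iter_child_nodes(_node):
            if hasattr(child, 'name'):
                child._location = parent_location + [child.name]
                scope[child] = child._location
            elif isinstance(child, ast.arg):
                child._location = parent_location + [child.arg]
                scope[child] = child._location
            else:
                # Unnamed nodes (e.g., ast.arguments) inherit their parent's path
                scope[child] = parent_location
    return node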
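If you would rather not depend on ast.walk at all, the same annotation can be made with a recursive visitor that carries the path down explicitly. This is a sketch of that alternative, not the answer's own method. LocationAnnotator is a hypothetical class, and it writes the same _location attribute, so the rest of the approach stays unchanged.

```python
import ast

class LocationAnnotator(ast.NodeVisitor):
    """Hypothetical alternative to annotate_ancestry: pushes/pops names on a stack."""

    def __init__(self):
        self.path = []

    def generic_visit(self, node):
        name = node.arg if isinstance(node, ast.arg) else getattr(node, "name", None)
        if isinstance(name, str):
            # Named scopes (classes, functions, ...) and parameters get a location
            node._location = self.path + [name]
            self.path.append(name)
            super().generic_visit(node)  # children see the extended path
            self.path.pop()
        else:
            super().generic_visit(node)

source = (
    "class Foo:\n"
    "    def f(self, g): pass\n"
    "    def h(self, k): pass\n"
)
tree = ast.parse(source)
LocationAnnotator().visit(tree)
locations = [n._location for n in ast.walk(tree) if hasattr(n, "_location")]
print(locations)
```

Parameters of sibling methods end up under their own method, for example ["Foo", "h", "k"]. The isinstance(name, str) check skips nodes whose name field may be None, such as a bare except handler.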
Then implement the ast.NodeTransformer mentioned earlier:
class RewriteAtQuery(ast.NodeTransformer):
    """
    Replace the node at query with given node
    :ivar search: Search query, e.g., ['class_name', 'method_name', 'arg_name']
    :ivar replacement_node: Node to replace this search
    :ivar replaced: Whether a replacement has been made
    """

    def __init__(self, search, replacement_node):
        """
        :param search: Search query
        :type search: ```List[str]```
        :param replacement_node: Node to replace this search
        :type replacement_node: ```ast.AST```
        """
        # Normalise to a list so tuples like ("Foo", "f", "g") also match _location
        self.search = list(search)
        self.replacement_node = replacement_node
        self.replaced = False

    def generic_visit(self, node):
        """
        Visit every node, replace once, and only if found
        :param node: AST node
        :type node: ```ast.AST```
        :returns: AST node, potentially edited
        :rtype: ```ast.AST```
        """
        if (not self.replaced and hasattr(node, '_location')
                and node._location == self.search):
            node = self.replacement_node
            self.replaced = True
        # NodeTransformer writes the returned node back into the parent's field
        return ast.NodeTransformer.generic_visit(self, node)
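The replacement works because ast.NodeTransformer writes whatever a visitor method returns back into the parent's field. Returning None removes the node instead. The minimal sketch below shows that contract in isolation. ReplaceArg is a hypothetical class that matches by parameter name only, not by full query. The sketch also shows one practical detail: a freshly built node has no lineno/col_offset, so ast.fix_missing_locations is needed before compile().

```python
import ast

class ReplaceArg(ast.NodeTransformer):
    """Hypothetical example: swap every ast.arg called `target` for `new_arg`."""

    def __init__(self, target, new_arg):
        self.target = target
        self.new_arg = new_arg

    def visit_arg(self, node):
        # Whatever is returned here replaces `node` in the parent's list
        # (here: the ast.arguments.args list of Foo.f).
        return self.new_arg if node.arg == self.target else node

source = 'class Foo(object):\n    def f(self, g: str = "foo"): pass\n'
tree = ast.parse(source)
new_arg = ast.arg(arg="g", annotation=ast.Name(id="int", ctx=ast.Load()))
tree = ReplaceArg("g", new_arg).visit(tree)

# The new node has no position info; copy it from the parents before compiling.
tree = ast.fix_missing_locations(tree)
namespace = {}
exec(compile(tree, "<ast>", "exec"), namespace)
print(ast.unparse(tree))
print(namespace["Foo"].f.__annotations__, namespace["Foo"].f.__defaults__)
```

Note that the default is still "foo" after this swap, because defaults are stored on ast.arguments rather than on the ast.arg itself.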
Job done =)
Usage/test:
parsed_ast = ast.parse(class_with_method_and_body_types_str)
annotate_ancestry(parsed_ast)
rewrite_at_query = RewriteAtQuery(
    search="C.method_name.dataset_name".split("."),
    replacement_node=ast.arg(
        annotation=ast.Name(ctx=ast.Load(), id="int"),
        arg="dataset_name",
        type_comment=None,
    ),
)
# visit() returns the (possibly rewritten) tree, not the transformer itself
gen_ast = rewrite_at_query.visit(parsed_ast)
self.assertTrue(rewrite_at_query.replaced)
# Additional test to compare AST produced with desired AST [see repo]
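Reaching the exact "after" state from the question (g: int = 5) also requires changing the default, and the default does not live on the ast.arg. The sketch below does this under stated assumptions. set_param is a hypothetical helper that edits an ast.FunctionDef in place. It only replaces a default that already exists for positional parameters, because adding one would also require every later positional parameter to have a default. The function is located by hand here rather than by query.

```python
import ast

def set_param(func, name, annotation, default):
    """Hypothetical helper: give parameter `name` of `func` (an ast.FunctionDef)
    a new annotation and replace its existing default value."""
    params = func.args
    positional = params.posonlyargs + params.args
    # defaults line up with the *last* len(defaults) positional parameters
    offset = len(positional) - len(params.defaults)
    for i, a in enumerate(positional):
        if a.arg == name:
            if i < offset:
                raise ValueError(f"{name!r} has no default to replace")
            a.annotation = annotation
            params.defaults[i - offset] = default
            return func
    for i, a in enumerate(params.kwonlyargs):
        if a.arg == name:
            a.annotation = annotation
            # kw_defaults is parallel to kwonlyargs (None where there is no default)
            params.kw_defaults[i] = default
            return func
    raise KeyError(name)

tree = ast.parse('class Foo(object):\n    def f(self, g: str = "foo"): pass\n')
func = tree.body[0].body[0]  # Foo.f, located by hand here for brevity
set_param(func, "g", ast.Name(id="int", ctx=ast.Load()), ast.Constant(value=5))
tree = ast.fix_missing_locations(tree)

namespace = {}
exec(compile(tree, "<ast>", "exec"), namespace)
print(ast.unparse(tree))
```

Editing the existing nodes in place, rather than splicing in an AnnAssign, keeps the parameter list valid for compile().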