如何将 ast.NodeTransformer 也应用于导入的模块?
How do I apply an ast.NodeTransformer to imported modules as well?
我不仅想将 NodeTransformer 应用到当前文件的 AST,还想应用到所有被导入的代码。如果您运行下面的代码,会发现转换器确实生效了,但只作用于被读取和解析的那一个文件。我该如何修改这段代码,才能让转换器也作用于已解析代码中的所有导入?
a.py:
from b import q
def r(a):
    # Sample function in a.py: returns q(a) + 5, where q is imported from b.
    # Only comments are used here (no docstring): the transformer prints every
    # body statement, so a docstring would add an extra "Doing: ..." line.
    return q(a) + 5
b.py:
def q(n):
    # Calls r with n + 1. `r` is looked up in this module's globals, i.e. the
    # b.r defined just below -- not the `r` defined in a.py.
    return r(n + 1)
def r(n):
    # Returns n unchanged.
    return n
Main.py:
import ast
import astor
class trivial_transformer(ast.NodeTransformer):
    """Prefix every statement of every function body with a ``print`` of its source.

    Only the tree passed to ``visit`` is rewritten; modules that code imports
    are loaded by the regular import system and are left untouched.
    """

    def visit_FunctionDef(self, node):
        """Insert ``print("Doing: <statement source>")`` before each body statement.

        NOTE: the inserted call becomes the function's first statement, so an
        existing docstring is no longer picked up as ``__doc__``.
        """
        body = []
        for stmt in node.body:
            # ast.Constant replaces ast.Str, which is deprecated since Python 3.8
            # and removed in Python 3.14.
            message = ast.Constant(value="Doing: " + astor.to_source(stmt).strip())
            body.append(
                ast.Expr(
                    ast.Call(
                        func=ast.Name("print", ctx=ast.Load()),
                        args=[message],
                        keywords=[],
                    )
                )
            )
            # visit() rather than generic_visit(): generic_visit() skips the
            # statement's own visit_* method, so nested defs were not instrumented.
            new_stmt = self.visit(stmt)
            if isinstance(new_stmt, list):
                body.extend(new_stmt)
            elif new_stmt is not None:
                body.append(new_stmt)
        node.body = body
        return node
# Parse a.py, instrument its functions, and run it in a fresh namespace.
# Read bytes so ast.parse honours the file's own (PEP 263) source encoding,
# and close the file promptly instead of leaking the handle.
with open('a.py', 'rb') as source_file:
    parsed_ast = ast.fix_missing_locations(
        trivial_transformer().visit(ast.parse(source_file.read(), filename='a.py'))
    )
g = {}
# exec, not eval: the code object is compiled in 'exec' mode. The real filename
# makes tracebacks point at a.py.
exec(compile(parsed_ast, 'a.py', 'exec'), g)
# Only a.py was transformed: `from b import q` runs through the normal import
# system, so b's functions print no "Doing: ..." lines.
print(g['r'](5))
运行结果为:
Doing: return q(a) + 5
11
但我希望它输出:
Doing: return q(a) + 5
Doing: return r(n + 1)
Doing: return n
11
好吧,虽然花了些时间,但我终于搞定了(耶!):
import ast
import astor
import importlib
import sys
class trivial_transformer(ast.NodeTransformer):
    """Prefix every statement of every function body with a ``print`` of its source.

    Besides rewriting the tree passed to ``visit``, this transformer follows
    ``import`` and ``from ... import`` statements. Each absolutely-imported
    module that has Python source and is not yet in ``sys.modules`` is parsed,
    rewritten with this same transformer, executed, and registered in
    ``sys.modules``. A later ``import`` statement then returns the transformed
    module.
    """

    def processImport(self, imp):
        """Load module *imp* through this transformer and register it in ``sys.modules``.

        The module's source is parsed and rewritten with ``self.visit``, which
        in turn processes that module's own imports. It is then compiled and
        executed in a fresh module object.

        Nothing happens when *imp* is empty or ``None``, is already in
        ``sys.modules``, cannot be located, or has no Python source (built-in
        or extension modules). Such modules are left to the regular import
        system when the import statement actually runs. A missing module
        therefore still raises at its usual place, and
        ``try: import x / except ImportError`` fallbacks keep working.

        Raises whatever parsing or executing the module raises; in that case
        the module is removed from ``sys.modules`` again.
        """
        # `import importlib` alone does not guarantee the `util` submodule is loaded.
        import importlib.util

        if not imp or imp in sys.modules:
            return
        try:
            spec = importlib.util.find_spec(imp)
        except (ImportError, ValueError):
            # Parent package missing or unusable: defer to the real import.
            return
        if spec is None or spec.loader is None:
            return
        get_source = getattr(spec.loader, "get_source", None)
        source = get_source(imp) if get_source is not None else None
        if source is None:
            return

        filename = spec.origin or imp
        module = importlib.util.module_from_spec(spec)
        # Register before visiting/executing (as the import system does). A
        # circular import then finds this partially initialised module instead
        # of recursing into processImport forever.
        sys.modules[imp] = module
        try:
            tree = ast.fix_missing_locations(self.visit(ast.parse(source, filename=filename)))
            exec(compile(tree, filename, "exec"), module.__dict__)
        except BaseException:
            sys.modules.pop(imp, None)
            raise

        # Bind a submodule on its parent package (pkg.sub) like a normal import does.
        parent, _, child = imp.rpartition(".")
        if parent and parent in sys.modules:
            setattr(sys.modules[parent], child, module)

    def _process_dotted(self, name):
        """Run processImport on every package prefix of *name*, outermost first."""
        # Otherwise find_spec("pkg.sub") would import `pkg` untransformed.
        parts = name.split(".")
        for end in range(1, len(parts) + 1):
            self.processImport(".".join(parts[:end]))

    def visit_ImportFrom(self, node):
        """Transform the module named in ``from X import ...`` before it is imported."""
        # Relative imports (level > 0, module possibly None) cannot be resolved
        # by name here; leave them to the regular import system.
        if node.level == 0 and node.module:
            self._process_dotted(node.module)
        return node

    def visit_Import(self, node):
        """Transform every module named in ``import X, Y.Z`` before it is imported."""
        for alias in node.names:
            self._process_dotted(alias.name)
        return node

    def visit_FunctionDef(self, node):
        """Insert ``print("Doing: <statement source>")`` before each body statement.

        NOTE: the inserted call becomes the function's first statement, so an
        existing docstring is no longer picked up as ``__doc__``.
        """
        body = []
        for stmt in node.body:
            # ast.Constant replaces ast.Str, which is deprecated since Python 3.8
            # and removed in Python 3.14.
            message = ast.Constant(value="Doing: " + astor.to_source(stmt).strip())
            body.append(
                ast.Expr(
                    ast.Call(
                        func=ast.Name("print", ctx=ast.Load()),
                        args=[message],
                        keywords=[],
                    )
                )
            )
            # visit() rather than generic_visit(): generic_visit() skips the
            # statement's own visit_* method, so function-local imports were
            # never processed and nested defs were not instrumented.
            new_stmt = self.visit(stmt)
            if isinstance(new_stmt, list):
                body.extend(new_stmt)
            elif new_stmt is not None:
                body.append(new_stmt)
        node.body = body
        return node
# Name of the entry module to load through the transformer.
init = 'a'
# Transform `a` (and, recursively, everything it imports) and register it in sys.modules.
trivial_transformer().processImport(init)
import a  # found in sys.modules, so this binds the transformed module
# Print the result explicitly: a bare expression statement only echoes its
# value in an interactive session, so a script would never show the "11".
print(a.r(5))
我不仅想将 NodeTransformer 应用到当前文件的 AST,还想应用到所有被导入的代码。如果您运行下面的代码,会发现转换器确实生效了,但只作用于被读取和解析的那一个文件。我该如何修改这段代码,才能让转换器也作用于已解析代码中的所有导入?
a.py:
from b import q
def r(a):
    # Sample function in a.py: returns q(a) + 5, where q is imported from b.
    # Only comments are used here (no docstring): the transformer prints every
    # body statement, so a docstring would add an extra "Doing: ..." line.
    return q(a) + 5
b.py:
def q(n):
    # Calls r with n + 1. `r` is looked up in this module's globals, i.e. the
    # b.r defined just below -- not the `r` defined in a.py.
    return r(n + 1)
def r(n):
    # Returns n unchanged.
    return n
Main.py:
import ast
import astor
class trivial_transformer(ast.NodeTransformer):
    """Prefix every statement of every function body with a ``print`` of its source.

    Only the tree passed to ``visit`` is rewritten; modules that code imports
    are loaded by the regular import system and are left untouched.
    """

    def visit_FunctionDef(self, node):
        """Insert ``print("Doing: <statement source>")`` before each body statement.

        NOTE: the inserted call becomes the function's first statement, so an
        existing docstring is no longer picked up as ``__doc__``.
        """
        body = []
        for stmt in node.body:
            # ast.Constant replaces ast.Str, which is deprecated since Python 3.8
            # and removed in Python 3.14.
            message = ast.Constant(value="Doing: " + astor.to_source(stmt).strip())
            body.append(
                ast.Expr(
                    ast.Call(
                        func=ast.Name("print", ctx=ast.Load()),
                        args=[message],
                        keywords=[],
                    )
                )
            )
            # visit() rather than generic_visit(): generic_visit() skips the
            # statement's own visit_* method, so nested defs were not instrumented.
            new_stmt = self.visit(stmt)
            if isinstance(new_stmt, list):
                body.extend(new_stmt)
            elif new_stmt is not None:
                body.append(new_stmt)
        node.body = body
        return node
# Parse a.py, instrument its functions, and run it in a fresh namespace.
# Read bytes so ast.parse honours the file's own (PEP 263) source encoding,
# and close the file promptly instead of leaking the handle.
with open('a.py', 'rb') as source_file:
    parsed_ast = ast.fix_missing_locations(
        trivial_transformer().visit(ast.parse(source_file.read(), filename='a.py'))
    )
g = {}
# exec, not eval: the code object is compiled in 'exec' mode. The real filename
# makes tracebacks point at a.py.
exec(compile(parsed_ast, 'a.py', 'exec'), g)
# Only a.py was transformed: `from b import q` runs through the normal import
# system, so b's functions print no "Doing: ..." lines.
print(g['r'](5))
运行结果为:
Doing: return q(a) + 5
11
但我希望它输出:
Doing: return q(a) + 5
Doing: return r(n + 1)
Doing: return n
11
好吧,虽然花了些时间,但我终于搞定了(耶!):
import ast
import astor
import importlib
import sys
class trivial_transformer(ast.NodeTransformer):
    """Prefix every statement of every function body with a ``print`` of its source.

    Besides rewriting the tree passed to ``visit``, this transformer follows
    ``import`` and ``from ... import`` statements. Each absolutely-imported
    module that has Python source and is not yet in ``sys.modules`` is parsed,
    rewritten with this same transformer, executed, and registered in
    ``sys.modules``. A later ``import`` statement then returns the transformed
    module.
    """

    def processImport(self, imp):
        """Load module *imp* through this transformer and register it in ``sys.modules``.

        The module's source is parsed and rewritten with ``self.visit``, which
        in turn processes that module's own imports. It is then compiled and
        executed in a fresh module object.

        Nothing happens when *imp* is empty or ``None``, is already in
        ``sys.modules``, cannot be located, or has no Python source (built-in
        or extension modules). Such modules are left to the regular import
        system when the import statement actually runs. A missing module
        therefore still raises at its usual place, and
        ``try: import x / except ImportError`` fallbacks keep working.

        Raises whatever parsing or executing the module raises; in that case
        the module is removed from ``sys.modules`` again.
        """
        # `import importlib` alone does not guarantee the `util` submodule is loaded.
        import importlib.util

        if not imp or imp in sys.modules:
            return
        try:
            spec = importlib.util.find_spec(imp)
        except (ImportError, ValueError):
            # Parent package missing or unusable: defer to the real import.
            return
        if spec is None or spec.loader is None:
            return
        get_source = getattr(spec.loader, "get_source", None)
        source = get_source(imp) if get_source is not None else None
        if source is None:
            return

        filename = spec.origin or imp
        module = importlib.util.module_from_spec(spec)
        # Register before visiting/executing (as the import system does). A
        # circular import then finds this partially initialised module instead
        # of recursing into processImport forever.
        sys.modules[imp] = module
        try:
            tree = ast.fix_missing_locations(self.visit(ast.parse(source, filename=filename)))
            exec(compile(tree, filename, "exec"), module.__dict__)
        except BaseException:
            sys.modules.pop(imp, None)
            raise

        # Bind a submodule on its parent package (pkg.sub) like a normal import does.
        parent, _, child = imp.rpartition(".")
        if parent and parent in sys.modules:
            setattr(sys.modules[parent], child, module)

    def _process_dotted(self, name):
        """Run processImport on every package prefix of *name*, outermost first."""
        # Otherwise find_spec("pkg.sub") would import `pkg` untransformed.
        parts = name.split(".")
        for end in range(1, len(parts) + 1):
            self.processImport(".".join(parts[:end]))

    def visit_ImportFrom(self, node):
        """Transform the module named in ``from X import ...`` before it is imported."""
        # Relative imports (level > 0, module possibly None) cannot be resolved
        # by name here; leave them to the regular import system.
        if node.level == 0 and node.module:
            self._process_dotted(node.module)
        return node

    def visit_Import(self, node):
        """Transform every module named in ``import X, Y.Z`` before it is imported."""
        for alias in node.names:
            self._process_dotted(alias.name)
        return node

    def visit_FunctionDef(self, node):
        """Insert ``print("Doing: <statement source>")`` before each body statement.

        NOTE: the inserted call becomes the function's first statement, so an
        existing docstring is no longer picked up as ``__doc__``.
        """
        body = []
        for stmt in node.body:
            # ast.Constant replaces ast.Str, which is deprecated since Python 3.8
            # and removed in Python 3.14.
            message = ast.Constant(value="Doing: " + astor.to_source(stmt).strip())
            body.append(
                ast.Expr(
                    ast.Call(
                        func=ast.Name("print", ctx=ast.Load()),
                        args=[message],
                        keywords=[],
                    )
                )
            )
            # visit() rather than generic_visit(): generic_visit() skips the
            # statement's own visit_* method, so function-local imports were
            # never processed and nested defs were not instrumented.
            new_stmt = self.visit(stmt)
            if isinstance(new_stmt, list):
                body.extend(new_stmt)
            elif new_stmt is not None:
                body.append(new_stmt)
        node.body = body
        return node
# Name of the entry module to load through the transformer.
init = 'a'
# Transform `a` (and, recursively, everything it imports) and register it in sys.modules.
trivial_transformer().processImport(init)
import a  # found in sys.modules, so this binds the transformed module
# Print the result explicitly: a bare expression statement only echoes its
# value in an interactive session, so a script would never show the "11".
print(a.r(5))