删除 Python 源代码中未使用的变量

Remove unused variables in Python source code

问题

是否有一个简单的算法来确定变量是否在给定范围内“使用”?

在 Python AST 中,我想删除给定范围内所有未在任何地方使用的变量的赋值。


详情

励志例子

在下面的代码中,对我(人类)来说很明显,_hy_anon_var_1 没有被使用,因此可以删除 _hy_anon_var_1 = None 语句而不改变结果:

# Before
def hailstone_sequence(n: int) -> Iterable[int]:
    while n != 1:
        if 0 == n % 2:
            n //= 2
            _hy_anon_var_1 = None
        else:
            n = 3 * n + 1
            _hy_anon_var_1 = None
        yield n

# After
def hailstone_sequence(n: int) -> Iterable[int]:
    while n != 1:
        if 0 == n % 2:
            n //= 2
        else:
            n = 3 * n + 1
        yield n

奖励版本

将其扩展到 []-以字符串文字作为键的查找。

在这个例子中,我希望 _hyx_letXUffffX25['x'] 被删除为未使用,因为 _hyx_letXUffffX25 对于 h 是本地的,所以 _hyx_letXUffffX25['x'] 本质上是相同的东西局部变量。然后我希望 _hyx_letXUffffX25 本身在没有更多引用后被删除。

# Before
def h():
    _hyx_letXUffffX25 = {}
    _hyx_letXUffffX25['x'] = 5
    return 3

# After
def h():
    return 3

据我所知,这有点像边缘情况,我认为基本的算法问题是相同的。

“二手”的定义

假设代码中没有使用动态名称查找。

如果在给定范围内满足以下任一条件,则名称已使用

  1. 它在表达式中的任何地方被引用。示例包括:return 语句中的表达式、赋值语句右侧的表达式、函数定义中的默认参数、在局部函数定义中被引用等。
  2. 它在"augmented assignment" statement, i.e. it is an augtarget 的左侧引用。这可能代表许多程序中的“无用工作”,但就此任务而言,这与完全未使用的名称不同。
  3. nonlocalglobal。这些可能是无用的非局部变量或全局变量,但因为它们超出了给定的范围,所以就我的目的而言,假设它们是“已使用的”是可以的。

如果这看起来不正确,或者您认为我遗漏了什么,请在评论中告诉我。

“已使用”和“未使用”的示例

示例 1:未使用

f 中的变量 i 未使用:

def f():
    i = 0
    return 5

示例 2:未使用

f 中的变量 x 未使用:

def f():
    def g(x):
        return x/5
    x = 10
    return g(100)

名称 x 确实出现在 g 中,但 g 中的变量 xg 的本地变量。它隐藏了在 f 中创建的变量 x,但是两个 x 名称不是同一个变量。

变化

如果g没有参数x,那么实际上使用的是x

def f():
    x = 10
    def g():
        return x/5
    return g(100)

示例 3:已使用

f 中的变量 i 使用:

def f():
    i = 0
    return i

示例 4:使用

silly_mapsilly_sum 中的变量 accum 在两个示例中 使用 :

def silly_map(func, data):
    data = iter(data)
    accum = []

    def _impl():
        try:
            value = next(data)
        except StopIteration:
            return accum
        else:
            accum.append(value)
            return _impl()

    return _impl()
def silly_any(func, data):
    data = iter(data)
    accum = False

    def _impl():
        nonlocal accum, data
        try:
            value = next(data)
        except StopIteration:
            return accum
        else:
            if value:
                data = []
                accum = True
            else:
                return _impl()

    return _impl()

下面的解决方案分为两部分。首先,遍历源的语法树,发现所有未使用的目标赋值语句。其次,通过自定义 ast.NodeTransformer class 再次遍历树,删除这些有问题的赋值语句。重复该过程,直到删除所有未使用的赋值语句。一旦完成,最终的源代码就写出来了。

ast遍历器class:

import ast, itertools, collections as cl
class AssgnCheck:
   def __init__(self, scopes = None):
      self.scopes = scopes or cl.defaultdict(list)
   @classmethod
   def eq_ast(cls, a1, a2):
      #check that two `ast`s are the same
      if type(a1) != type(a2):
         return False
      if isinstance(a1, list):
         return all(cls.eq_ast(*i) for i in itertools.zip_longest(a1, a2))
      if not isinstance(a1, ast.AST):
         return a1 == a2
      return all(cls.eq_ast(getattr(a1, i, None), getattr(a2, i, None)) 
                 for i in set(a1._fields)|set(a2._fields) if i != 'ctx')
   def check_exist(self, t_ast, s_path):
      #traverse the scope stack and remove scope assignments that are discovered in the `ast`
      s_scopes = []
      for _ast in t_ast:
         for sid in s_path[::-1]:
            s_scopes.extend(found:=[b for _, b in self.scopes[sid] if AssgnCheck.eq_ast(_ast, b) and \
                all(not AssgnCheck.eq_ast(j, b) for j in s_scopes)])
            self.scopes[sid] = [(a, b) for a, b in self.scopes[sid] if b not in found]
   def traverse(self, _ast, s_path = [1]):
      #walk the ast object itself
      _t_ast = None
      if isinstance(_ast, ast.Assign): #if assignment statement, add ast object to current scope
         self.traverse(_ast.targets[0], s_path)
         self.scopes[s_path[-1]].append((True, _ast.targets[0]))
         _ast = _ast.value
      if isinstance(_ast, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
         s_path = [*s_path, (nid:=(1 if not self.scopes else max(self.scopes)+1))]
         if isinstance(_ast, (ast.FunctionDef, ast.AsyncFunctionDef)):
            self.scopes[nid].extend([(False, ast.Name(i.arg)) for i in _ast.args.args])
            _t_ast = [*_ast.args.defaults, *_ast.body]
      self.check_exist(_t_ast if _t_ast is not None else [_ast], s_path) #determine if any assignment statement targets have previously defined names
      if _t_ast is None:
         for _b in _ast._fields:
            if isinstance((b:=getattr(_ast, _b)), list):
               for i in b:
                  self.traverse(i, s_path)
            elif isinstance(b, ast.AST):
               self.traverse(b, s_path)
      else:
          for _ast in _t_ast:
             self.traverse(_ast, s_path)
         

综合起来:

class Visit(ast.NodeTransformer):
   def __init__(self, asgn):
       super().__init__()
       self.asgn = asgn
   def visit_Assign(self, node):
       #remove assignment nodes marked as unused
       if any(node.targets[0] == i for i in self.asgn):
          return None
       return node

def remove_assgn(f_name):
  tree = ast.parse(open(f_name).read())
  while True:
     r = AssgnCheck()
     r.traverse(tree)
     if not (k:=[j for b in r.scopes.values() for k, j in b if k]):
        break
     v = Visit(k)
     tree = v.visit(tree)
  return ast.unparse(tree)

print(remove_assgn('test_name_assign.py'))

输出样本

test_name_assign.py 的内容:

def hailstone_sequence(n: int) -> Iterable[int]:
    while n != 1:
        if 0 == n % 2:
            n //= 2
            _hy_anon_var_1 = None
        else:
            n = 3 * n + 1
            _hy_anon_var_1 = None
        yield n

输出:

def hailstone_sequence(n: int) -> Iterable[int]:
    while n != 1:
        if 0 == n % 2:
            n //= 2
        else:
            n = 3 * n + 1
        yield n

test_name_assign.py 的内容:

def h():
    _hyx_letXUffffX25 = {}
    _hyx_letXUffffX25['x'] = 5
    return 3

输出:

def h():
    return 3

test_name_assign.py 的内容:

def f():
    i = 0
    return 5

输出:

def f():
    return 5

test_name_assign.py 的内容:

def f():
    x = 10
    def g():
        return x/5
    return g(100)

输出:

def f():
    x = 10
    def g():
        return x / 5
    return g(100)