Python 中的递归函数:获取特定嵌套项列表的最佳方法

Recursive function in Python: Best way to get a list of specific nested items

我有一棵嵌套字典树。这是一个小摘录,只是为了给你一个想法:

db = {
    'compatibility': {
        'style': {
            'path_to_file': 'compatibility/render/style.py',
            'checksum': {
                '0.0.3':'AAA55d796c25ad867bbcb8e0da4e48d17826e6f9fce',
                '0.0.2': '55d796c25ad867bbcb8e0da4e48d17826e6f9fe606',}}},
    'developer': {
        'render': {
            'installation': {
                'path_to_file': 'developer/render/installation.py',
                'checksum': {
                    '0.0.1': 'c1c0d4080e72292710ac1ce942cf59ce0e26319cf3'}},
            'tests': {
                'path_to_file': 'developer/render/test.py',
                'checksum': {
                    '0.0.1': 'e71173ac43ecd949fdb96cfb835abadb877a5233a36b115'}}}}}

我想获取树中嵌套的所有字典模块的列表。这样我就可以循环列表并测试每个文件的校验和(请注意,模块可以像上面的示例一样处于不同级别)。

为此,我编写了以下递归函数。我知道每个模块都有一个 "path_to_file" 和 "checksum" 键,所以我用它来测试字典是否是一个模块。请注意,我必须将递归函数包装在另一个包含列表的函数中,这样每次递归函数运行时列表都不会被覆盖。

def _get_modules_from_db(dictionary):
    def recursive_find(inner_dictionary):
        for k, v in inner_dictionary.iteritems():
            if (isinstance(v, dict) and
                    not sorted(v.keys()) == ['path_to_file', 'sha512sum']):
                recursive_find(v)
            else:
                leaves.append(v)
    leaves = []
    recursive_find(dictionary)
    return leaves

这种方法可行,但是必须包装函数对我来说似乎很难看。所以,我对 Stack Overflow 专业人士的问题:

您是否可以推荐更简单(或更好)的方法来实现此目的而无需包装函数?

在我个人看来,嵌套函数很好,但这里有一个更简洁的版本

from operator import add

def _get_modules_from_db(db):
  if 'path_to_file' in db and 'sha512sum' in db:
    return [db]
  return reduce(add, (_get_modules_from_db(db[m]) for m in db))

我认为这种方法没有问题。您想要一个操纵某些全局状态的递归函数 - 这是一种非常合理的方法(内部函数在 Python 中并不少见)。

也就是说,如果您想避免嵌套函数,您可以添加默认参数:

def _get_modules_from_db(db, leaves=None):
    if leaves is None:
        leaves = []
    if not isinstance(db, dict):
        return leaves

    # Use 'in' check to avoid sorting keys and doing a list compare
    if 'path_to_file' in db and 'checksum' in db:
        leaves.append(db)
    else:
        for v in db.values():
            _get_modules_from_db(v, leaves)

    return leaves

首先,您需要包装函数的唯一原因是因为您要 recursive_find 就地改变 leaves 闭包单元而不是 return 改变它。有时这是一个有用的性能优化(尽管它经常是一种悲观),有时只是不清楚如何去做,但这次是微不足道的:

def _get_modules_from_db(dictionary):
    leaves = []
    for k, v in dictionary.iteritems():
        if (isinstance(v, dict) and
            not sorted(v.keys()) == ['path_to_file', 'sha512sum']):
            leaves.extend(_get_modules_from_db(v))
        else:
            leaves.append(v)
    return leaves

对于其他改进:我可能会把它变成一个生成器(至少在 3.3+ 中,使用 yield from;在 2.7 中我可能会三思而后行)。而且,当我们这样做时,我会将键视图(在 3.x 中)或 set(v)(在 2.x 中)与一个集合进行比较,而不是进行不必要的排序(并且没有理由 .keys()setsorted),并使用 != 而不是 not==。而且,除非有充分的理由只接受 dictdict 子类,否则我要么直接输入它,要么使用 collections.[abc.]Mapping。所以:

def _get_modules_from_db(dictionary):
    for k, v in dictionary.items():
        if isinstance(v, Mapping) and v.keys() != {'path_to_file', 'sha512sum'}:
            yield from _get_modules_from_db(v)
        else:
            yield v

或者,或者,将基本情况拉出,这样您就可以直接在字符串上调用它:

def _get_modules_from_db(d):
    if isinstance(d, Mapping) and d.keys() != {'path_to_file', 'sha512sum'}:
        for v in d.values():
            yield from _get_modules_from_db(v)
    else:
        yield d

我认为这比您所拥有的更具可读性,它是 6 行而不是 11 行(尽管 2.x 版本是 7 行)。但我没有发现您的版本有任何实际问题。


如果您不确定如何将 3.3+ 代码转换为 2.7/3.2 代码:

  • yield from eggs 重写为 for egg in eggs: yield egg
  • Mappingcollections,而不是 collections.abc
  • 使用 set(v) 而不是 v.keys()
  • 可能使用 itervalues 而不是 values(仅限 2.x)。