如何在缓存结果时更新非局部变量?

How to update nonlocal variables while caching results?

当使用像lru_cache这样的functools缓存函数时,内部函数不会更新非局部变量的值。同样的方法在没有装饰器的情况下也有效。

使用缓存装饰器时,非局部变量不会更新吗?另外,如果我必须更新非局部变量但还要存储结果以避免重复工作怎么办?或者我是否需要 return 来自缓存函数的答案?

例如。以下未正确更新非局部变量的值

def foo(x):
    outer_var=0

    @lru_cache
    def bar(i):
        nonlocal outer_var
        if condition:
            outer_var+=1
        else:
            bar(i+1)

    bar(x)
    return outer_var

背景

我正在尝试 Decode Ways problem,它正在寻找可以将一串数字解释为字母的方式的数量。我从第一个字母开始,采取一两个步骤来检查它们是否有效。到达字符串末尾时,我更新了一个非局部变量,它存储了可能的方法数。此方法在不使用 lru_cache 的情况下给出正确答案,但在使用缓存时失败。我 return 值有效但我想检查如何在使用记忆装饰器时更新非局部变量的另一种方法。

我的错误代码:

ways=0
@lru_cache(None) # works well without this
def recurse(i):
    nonlocal ways
    if i==len(s):
        ways+=1
    elif i<len(s):
        if 1<=int(s[i])<=9:
            recurse(i+1)
        if i+2<=len(s) and 10<=int(s[i:i+2])<=26:
            recurse(i+2)
    return 

recurse(0)
return ways

接受的解决方案:

@lru_cache(None)
def recurse(i):
    if i==len(s):
        return 1

    elif i<len(s):
        ans=0
        if 1<=int(s[i])<=9:
            ans+= recurse(i+1)
        if i+2<=len(s) and 10<=int(s[i:i+2])<=26:
            ans+= recurse(i+2)
        return ans

return recurse(0)

lru_cache 没有什么特别之处,一个 nonlocal 变量或递归本身会导致这里的任何固有问题。该问题纯粹是逻辑问题,而不是行为异常。请参阅这个最小示例:

from functools import lru_cache

def foo():
    c = 0

    @lru_cache(None)
    def bar(i=0):
        nonlocal c

        if i < 5:
            c += 1
            bar(i + 1)

    bar()
    return c

print(foo()) # => 5

解码方式代码的缓存版本中的问题是由于递归调用的重叠性质。缓存阻止基本情况调用 recurse(i) 其中 i == len(s) 执行多次,即使它是从不同的递归路径到达的。

建立这一点的一个好方法是在基本情况(if i == len(s) 分支)中打一个 print("hello"),然后给它一个相当大的问题。您会看到 print("hello") 触发一次,并且仅触发一次,并且由于 waysi == len(s) 时无法通过 recurse(i) 以外的任何其他方式更新,因此您只剩下 ways == 1 当一切都说完了。

在上面的玩具示例中,只有一个递归路径:调用扩展到 0 到 9 之间的每个 i,并且从不使用缓存。相比之下,解码方式提供了多个递归路径,因此通过 recurse(i+1) 的路径线性地找到基本情况,然后随着堆栈的展开,recurse(i+2) 尝试找到到达它的其他方式。

添加缓存会切断额外的路径,但对每个中间节点return没有任何价值。使用缓存,就好像您有子问题的记忆化或动态编程 table,但您从不更新任何条目,因此整个 table 为零(基本情况除外)。

这是缓存导致的线性行为的示例:

from functools import lru_cache

def cached():
    @lru_cache(None)
    def cached_recurse(i=0):
        print("cached", i)

        if i < 3:
            cached_recurse(i + 1)
            cached_recurse(i + 2)

    cached_recurse()

def uncached():
    def uncached_recurse(i=0):
        print("uncached", i)

        if i < 3:
            uncached_recurse(i + 1)
            uncached_recurse(i + 2)

    uncached_recurse()

cached()
uncached()

输出:

cached 0
cached 1
cached 2
cached 3
cached 4
uncached 0
uncached 1
uncached 2
uncached 3
uncached 4
uncached 3
uncached 2
uncached 3
uncached 4

解决方案与您展示的完全一样:将结果向上传递到树中,并使用缓存来存储表示子问题的每个节点的值。这是两全其美的做法:我们拥有子问题的值,但无需重新执行最终导致 ways += 1 基本情况的函数。

换句话说,如果您要使用缓存,请将其视为查找 table,而不仅仅是调用树修剪器。在您的尝试中,它不记得完成了什么工作,只是阻止它再次完成。