Python:根据新的函数参数调用缓存的函数结果

Python: recall cached function result dependent on new function parameter

我对缓存和记忆化的概念还很陌生。我已经阅读了一些其他的讨论和资源 here, here, and here,但未能很好地理解它们。

假设我在一个 class 中有两个成员函数。 (下面的简化示例。)假设第一个函数 total 计算量很大。第二个函数 subtotal 在计算上很简单,只是它使用了第一个函数的 return,因此计算量也很大,因为它目前需要重新调用 total 得到它的 returned 结果。

我想缓存第一个函数的结果并将其用作第二个函数的输入,if 输入 ysubtotal 份额输入 x 到最近调用 total。即:

示例:

class MyObject(object):

    def __init__(self, a, b):
        self.a, self.b = a, b

    def total(self, x):
        return (self.a + self.b) * x     # some time-expensive calculation

    def subtotal(self, y, z):
        return self.total(x=y) + z       # Don't want to have to re-run total() here
                                         # IF y == x from a recent call of total(),
                                         # otherwise, call total().

对于 Python3.2 或更高版本,您可以使用 functools.lru_cache。 如果你直接用 functools.lru_cache 修饰 total,那么 lru_cache 会根据两个参数的值缓存 total 的 return 值,selfx。由于 lru_cache 的内部字典存储了对 self 的引用,将 @lru_cache 直接应用于 class 方法会创建对 self 的循环引用,从而生成 [=46] 的实例=] un-dereferencable(因此内存泄漏)。

which allows you to use lru_cache with class methods -- it caches results based on all arguments other than the first one, self, and uses a weakref避免循环引用问题:

import functools
import weakref

def memoized_method(*lru_args, **lru_kwargs):
    """
     (orly)
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapped_func(self, *args, **kwargs):
            # We're storing the wrapped method inside the instance. If we had
            # a strong reference to self the instance would never die.
            self_weak = weakref.ref(self)
            @functools.wraps(func)
            @functools.lru_cache(*lru_args, **lru_kwargs)
            def cached_method(*args, **kwargs):
                return func(self_weak(), *args, **kwargs)
            setattr(self, func.__name__, cached_method)
            return cached_method(*args, **kwargs)
        return wrapped_func
    return decorator


class MyObject(object):

    def __init__(self, a, b):
        self.a, self.b = a, b

    @memoized_method()
    def total(self, x):
        print('Calling total (x={})'.format(x))
        return (self.a + self.b) * x


    def subtotal(self, y, z):
        return self.total(x=y) + z 

mobj = MyObject(1,2)
mobj.subtotal(10, 20)
mobj.subtotal(10, 30)

打印

Calling total (x=10)

只有一次。


或者,这是您可以使用字典滚动自己的缓存的方法:

class MyObject(object):

    def __init__(self, a, b):
        self.a, self.b = a, b
        self._total = dict()

    def total(self, x):
        print('Calling total (x={})'.format(x))
        self._total[x] = t = (self.a + self.b) * x
        return t

    def subtotal(self, y, z):
        t = self._total[y] if y in self._total else self.total(y)
        return t + z 

mobj = MyObject(1,2)
mobj.subtotal(10, 20)
mobj.subtotal(10, 30)

lru_cache 优于此基于字典的缓存的一个优点是 lru_cache 是线程安全的。 lru_cache 还有一个 maxsize 参数可以帮助 防止内存使用无限制地增长(例如,由于 long-运行 进程使用不同的 x).

多次调用 total

感谢大家的回复,阅读它们并了解幕后发生的事情很有帮助。正如@Tadhg McDonald-Jensen 所说,似乎我在这里只需要 @functools.lru_cache。 (我在 Python 3.5。)关于@unutbu 的评论,我没有收到用 @lru_cache 装饰 total() 的错误。让我更正我自己的例子,我会把它留在这里供其他初学者使用:

from functools import lru_cache
from datetime import datetime as dt

class MyObject(object):
    def __init__(self, a, b):
        self.a, self.b = a, b

    @lru_cache(maxsize=None)
    def total(self, x):        
        lst = []
        for i in range(int(1e7)):
            val = self.a + self.b + x    # time-expensive loop
            lst.append(val)
        return np.array(lst)     

    def subtotal(self, y, z):
        return self.total(x=y) + z       # if y==x from a previous call of
                                         # total(), used cached result.

myobj = MyObject(1, 2)

# Call total() with x=20
a = dt.now()
myobj.total(x=20)
b = dt.now()
c = (b - a).total_seconds()

# Call subtotal() with y=21
a2 = dt.now()
myobj.subtotal(y=21, z=1)
b2 = dt.now()
c2 = (b2 - a2).total_seconds()

# Call subtotal() with y=20 - should take substantially less time
# with x=20 used in previous call of total().
a3 = dt.now()
myobj.subtotal(y=20, z=1)
b3 = dt.now()
c3 = (b3 - a3).total_seconds()

print('c: {}, c2: {}, c3: {}'.format(c, c2, c3))
c: 2.469753, c2: 2.355764, c3: 0.016998

在这种情况下,我会做一些简单的事情,也许不是最优雅的方式,但可以解决问题:

class MyObject(object):
    param_values = {}
    def __init__(self, a, b):
        self.a, self.b = a, b

    def total(self, x):
        if x not in MyObject.param_values:
          MyObject.param_values[x] = (self.a + self.b) * x
          print(str(x) + " was never called before")
        return MyObject.param_values[x]

    def subtotal(self, y, z):
        if y in MyObject.param_values:
          return MyObject.param_values[y] + z
        else:
          return self.total(y) + z