查找所提供序列的第 N 项

Find Nth term of provided sequence

f(0) = p

f(1) = q

f(2) = r

for n > 2

f(n) = af(n-1) + bf(n-2) + c*f(n-3) + g(n)

where g(n) = n* n* (n+1)

p,q,r,a,b,c 已给出 问题是,如何找到这个系列的第 n 项。

请帮助我找到更好的解决方案。

我试过使用递归来解决这个问题。但是这样很耗内存。

比递归更好的方法是memoization。您只需要知道 f(n) 的最后三个值。 伪代码中的解决方案可能如下所示:

if n == 0:
    return p
else if n == 1:
    return q
else if n == 2:
    return r
else:    
    f_n-3 = p
    f_n-2 = q
    f_n-1 = r
    for i from 3 to n:
        f_new = a * f_n-1 + b * f_n-2 + c * f_n-3 + g(n)
        fn-1 = fn-2
        fn-2 = fn-3
        fn-3 = f_new

return f_new

这样你就不需要递归调用方法并将所有计算的值保存在堆栈中,而只需在内存中保存 4 个变量。

这应该会计算得更快并且使用更少的内存。

问题是,每次使用 n > 2 调用 f,都会导致对 f 的三个额外调用。例如,如果我们调用 f(5),我们会得到以下调用:

- f(5)
    - f(4)
        - f(3)
            - f(2)
            - f(1)
            - f(0)
            - g(3)
        - f(2)
        - f(1)
        - g(4)
    - f(3)
        - f(2)
        - f(1)
        - f(0)
        - g(3)
    - f(2)
    - g(5)

因此我们调用了一次 f(5),调用了一次 f(4),调用了两次 f(3),调用了四次 f(2),调用了三次 f(1),以及对 f(0).

的两次调用

由于我们多次调用 f(3),这意味着每次都会消耗资源,特别是因为 f(3) 本身会进行额外的调用。

我们可以让Python存储函数调用的结果,而return存储结果,例如lru_cache [Python-doc]。这种技术称为 memoization:

from functools import <b>lru_cache</b>

def g(n):
    return n * n * (n+1)

<b>@lru_cache(maxsize=32)</b>
def f(n):
    if n <= 2:
        return (p, q, r)[n]
    else:
        return a*f(n-1) + b*f(n-2) + c*f(n-3) + g(n)

这将导致调用图如下:

- f(5)
    - f(4)
        - f(3)
            - f(2)
            - f(1)
            - f(0)
            - g(3)
        - g(4)
    - g(5)

所以现在我们只计算一次f(3)lru_cache会把它存入缓存,如果我们第二次调用f(3),我们永远不会计算f(3) 本身,缓存将 return pre-computed 值。

不过这里可以优化一下,因为我们每次调用f(n-1)f(n-2)f(n-3),我们只需要存储最后三个值,每次计算基于最后三个值的下一个值,并移动变量,如:

def f(n):
    if n <= 2:
        return (p, q, r)[n]
    f3, f2, f1 = p, q, r
    for i in range(3, n+1):
        f3, f2, f1 = f2, f1, a * f1 + b * f2 + c * f3 + g(i)
    return f1