简单的链式法则 python

Chain Rule in plain python

我看到 sympy 的问题已经得到解答,但我正在尝试在没有第三方库的玩具项目上编写链式法则的实现以用于教育目的。

基本上链式规则是k'(x) = f'(g(x)) * g'(x) where k(x) = f(g(x))

我有以下功能:

def g(x):
    return x**3 + 2

def f(x):
    return x**2 + 7

def de(fn, x, step):
    t1 = fn(x)
    t2 = fn(x+step)
    return (t2 - t1) / step

def chain(x):
    return f(g(x))

def de_chain(x, step):
    d_g = de(g, x, step)
    gres = g(x)
    d_f_g = de(f, gres, step)
    return d_g * d_f_g

问题是当我为 x=1.2step=2.6 计算 de_chainde(chain) 时,我得到 de(chain) = 205.5446...de_chain = 1238.6639...

这里有问题,因为与 k'(x) = g'(x) + f'(x) where k(x) = g(x) + f(x) 中相同的方法应用于加法和减法 结果非常非常接近。我做错了什么?

谢谢

您的代码看起来正确。问题是,一般来说,仅用一个差值进行导数估计并不是非常准确,而且这里你的步长非常大。请记住导数是您的 de 函数,但它是 limit 随着 step 变为 0.

只考虑你的 g(x)。它在 x=1 处的实际导数是 3*x^2 = 3 * 1^2 = 3。但是如果你的步长为 2.6,你会得到 4.6 的估计值,这与目标相去甚远。

您可以在此处阅读更准确的导数估算方法:https://en.wikipedia.org/wiki/Numerical_differentiation