a 和 b 的导数,使用算法微分

derivatives by a and b, using using algorithmic differentiation

我的任务是使用 jax 找到 a 和 b 的导数,为此 function

现在,我来这里的原因是因为我了解得不够Python,而且对于有问题的课程,我们也没有想到python。

作业是:

return a tuple (dfa, dfb) such that dfa is the partial derivatives of f by a,
           and dfb is the partial derivative of f by b

现在,我可以用正常的方式完成了:

def function(a, b):
   dfa = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
   dfb = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
   return (dfa, dfb)

但我不熟悉算法微分,使用我们给出的例子,我试过这个:

def foo():

   x = (2/b)*sym.cos(a)
   y = sym.exp(-sym.Pow(a/b,2))
   return (x*y)

def f_partial_derviatives_algo():
   return jax.grad(foo)

但我收到此错误:

cannot unpack non-iterable function object

如果有人能帮助理解我如何做这样的事情,将不胜感激

JAX 和 sympy 不兼容。您应该使用其中之一,而不是尝试将两者结合使用。

如果你想使用 JAX 计算这个函数在某个值的偏导数,你可以这样写:

import jax.numpy as jnp
from jax import grad

def f(a, b):
  return (2 / b) * jnp.cos(a) * jnp.exp(- a ** 2 / b ** 2)

df_da = grad(f, argnums=0)
df_db = grad(f, argnums=1)

print(df_da(1.0, 1.0), df_db(1.0, 1.0))
# -1.4141841 0.3975322