a 和 b 的导数,使用算法微分
derivatives by a and b, using using algorithmic differentiation
我的任务是使用 jax 找到 a 和 b 的导数,为此
function
现在,我来这里的原因是因为我了解得不够Python,而且对于有问题的课程,我们也没有想到python。
作业是:
return a tuple (dfa, dfb) such that dfa is the partial derivatives of f by a,
and dfb is the partial derivative of f by b
现在,我可以用正常的方式完成了:
def function(a, b):
dfa = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
dfb = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
return (dfa, dfb)
但我不熟悉算法微分,使用我们给出的例子,我试过这个:
def foo():
x = (2/b)*sym.cos(a)
y = sym.exp(-sym.Pow(a/b,2))
return (x*y)
def f_partial_derviatives_algo():
return jax.grad(foo)
但我收到此错误:
cannot unpack non-iterable function object
如果有人能帮助理解我如何做这样的事情,将不胜感激
JAX 和 sympy 不兼容。您应该使用其中之一,而不是尝试将两者结合使用。
如果你想使用 JAX 计算这个函数在某个值的偏导数,你可以这样写:
import jax.numpy as jnp
from jax import grad
def f(a, b):
return (2 / b) * jnp.cos(a) * jnp.exp(- a ** 2 / b ** 2)
df_da = grad(f, argnums=0)
df_db = grad(f, argnums=1)
print(df_da(1.0, 1.0), df_db(1.0, 1.0))
# -1.4141841 0.3975322
我的任务是使用 jax 找到 a 和 b 的导数,为此 function
现在,我来这里的原因是因为我了解得不够Python,而且对于有问题的课程,我们也没有想到python。
作业是:
return a tuple (dfa, dfb) such that dfa is the partial derivatives of f by a,
and dfb is the partial derivative of f by b
现在,我可以用正常的方式完成了:
def function(a, b):
dfa = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
dfb = sym.diff((2/b)*sym.cos(a)*sym.exp(-a*a/b*b), a)
return (dfa, dfb)
但我不熟悉算法微分,使用我们给出的例子,我试过这个:
def foo():
x = (2/b)*sym.cos(a)
y = sym.exp(-sym.Pow(a/b,2))
return (x*y)
def f_partial_derviatives_algo():
return jax.grad(foo)
但我收到此错误:
cannot unpack non-iterable function object
如果有人能帮助理解我如何做这样的事情,将不胜感激
JAX 和 sympy 不兼容。您应该使用其中之一,而不是尝试将两者结合使用。
如果你想使用 JAX 计算这个函数在某个值的偏导数,你可以这样写:
import jax.numpy as jnp
from jax import grad
def f(a, b):
return (2 / b) * jnp.cos(a) * jnp.exp(- a ** 2 / b ** 2)
df_da = grad(f, argnums=0)
df_db = grad(f, argnums=1)
print(df_da(1.0, 1.0), df_db(1.0, 1.0))
# -1.4141841 0.3975322