Jax:对向量值参数的索引取导数

Jax: Take derivative with respect to index of vector-valued argument

Jax 是否支持取导数w.r.t。向量值变量的索引?考虑这个例子(其中 a 是 vector/array):

def test_func(a):
  return a[0]**a[1]

我可以将参数编号传递给 grad(..),但我似乎无法像上例那样传递向量值参数的索引。我尝试传递一个元组的元组,即

grad(test_func, argnums=((0,),))

但这不起作用。

没有可以对数组的某些元素采用梯度的内置变换,但您可以通过将数组拆分为单个元素的包装函数直接执行此操作;例如:

import jax
import jax.numpy as jnp

def test_func(a):
  return a[0]**a[1]

a = jnp.array([1.0, 2.0])
fgrad = jax.grad(lambda *args: test_func(jnp.array(args)), argnums=0)
print(fgrad(*a))
# 2.0

如果你想对 所有 单独的输入进行梯度处理(返回每个条目的梯度向量),你可以使用 jax.jacobian :

print(jax.jacobian(test_func)(a))
# [2. 0.]