基于秩的计算的自动微分

Automatic Differentiation with respect to rank-based computations

我是自动微分编程的新手,所以这可能是一个幼稚的问题。以下是我要解决的问题的简化版本。

我有两个输入数组 - 一个大小为 N 的向量 A 和一个形状为 (N, M) 的矩阵 B,以及一个参数向量 theta 大小 M。我定义了一个新数组 C(theta) = B * theta 以获得大小为 N 的新向量。然后,我获取落在 C 的上四分位数和下四分位数中的元素的索引,并使用它们创建一个新数组 A_low(theta) = A[lower quartile indices of C]A_high(theta) = A[upper quartile indices of C]。显然这两个确实依赖于theta,但是是否可以区分A_lowA_highw.r.ttheta

到目前为止,我的尝试似乎表明没有 - 我使用了 autograd、JAX 和 tensorflow 的 python 库,但它们都 return 梯度为零。 (到目前为止,我尝试过的方法包括使用 argsort 或使用 tf.top_k 提取相关子数组。)

我正在寻求帮助的是证明导数未定义(或无法分析计算)或者如果它确实存在,关于如何估计它的建议。 我的最终目标是最小化某些函数 f(A_low, A_high) wrt theta.

这是我根据你的描述写的JAX计算:

import numpy as np
import jax.numpy as jnp
import jax

N = 10
M = 20

rng = np.random.default_rng(0)
A = jnp.array(rng.random((N,)))
B = jnp.array(rng.random((N, M)))
theta = jnp.array(rng.random(M))

def f(A, B, theta, k=3):
  C = B @ theta
  _, i_upper = lax.top_k(C, k)
  _, i_lower = lax.top_k(-C, k)
  return A[i_lower], A[i_upper]

x, y = f(A, B, theta)
dx_dtheta, dy_dtheta = jax.jacobian(f, argnums=2)(A, B, theta)

导数全为零,我相信这是正确的,因为输出值的变化不依赖于theta.

值的变化

但是,您可能会问,这怎么可能?毕竟,theta 进入计算,如果你为 theta 输入不同的值,你会得到不同的输出。梯度怎么可能为零?

不过,您必须牢记的是,微分并不能衡量输入是否影响输出。在给定输入 .

的无穷小变化的情况下,它测量输出的 变化

我们以一个稍微简单一点的函数为例:

import jax
import jax.numpy as jnp

A = jnp.array([1.0, 2.0, 3.0])
theta = jnp.array([5.0, 1.0, 3.0])

def f(A, theta):
  return A[jnp.argmax(theta)]

x = f(A, theta)
dx_dtheta = jax.grad(f, argnums=1)(A, theta)

此处ftheta的微分结果全为零,原因同上。为什么?如果对 theta 进行微小的更改,通常不会影响 theta 的排序顺序。因此,您从 A 中选择的条目不会随着 theta 的无穷小变化而改变,因此关于 theta 的导数为零。

现在,您可能会争辩说在某些情况下情况并非如此:例如,如果 theta 中的两个值非常接近,那么即使是无穷小地扰动一个值也肯定会改变它们各自的等级。这是真的,但是这个过程产生的梯度是不确定的(输出的变化相对于输入的变化并不平滑)。好消息是这种不连续性是单方面的:如果你在另一个方向上进行扰动,等级没有变化并且梯度是明确定义的。为了避免未定义的梯度,大多数自动微分系统将隐含地使用这种更安全的导数定义来进行基于秩的计算。

结果是,当您无限微扰输入时,输出值不会改变,这是梯度为零的另一种说法。这并不是 autodiff 的失败——考虑到 autodiff 所基于的微分定义,它是正确的梯度。此外,如果您尝试在这些不连续处更改导数的不同定义,您可能希望的最好结果是未定义的输出,因此导致零的定义可以说更有用和正确。