vmap gives inconsistent shape error when trying to calculate gradient per sample

I'm trying to implement a two-layer neural network and get the gradient of the second layer for each sample.

My code looks like this:

import jax.numpy as jnp
from jax import grad, nn, random, vmap

key = random.PRNGKey(0)  # any PRNG key; the exact seed isn't important here

x = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=jnp.float32)
y = jnp.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=jnp.float32)

W1 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=jnp.float32)
W2 = random.uniform(key, shape=(10,), minval=1, maxval=2, dtype=jnp.float32)
b = jnp.linspace(0, -9, 10, dtype=jnp.float32)
b = jnp.reshape(b, (1,10))

def predict(W1, W2, b, x):
  f1 = jnp.einsum('i,j->ji', W1, x)+b
  f1 = nn.relu(f1)
  f2 = jnp.einsum('i,ji->j', W2, f1)
  return f2

def loss(W1, W2, b, x, y):
  preds = predict(W1, W2, b, x)
  return jnp.mean(jnp.square(y-preds))

perex_grads = vmap(grad(loss, argnums=1), in_axes=(0, None, 0, 0, 0))
pers_grads = perex_grads(W1, W2, b, x, y)

Running loss and grad(loss) on their own works fine; running vmap is the actual problem.

The exact error I get is:

ValueError: vmap got inconsistent sizes for array axes to be mapped:
arg 0 has shape (10,) and axis 0 is to be mapped
arg 1 has shape (10,) and axis None is to be mapped
arg 2 has shape (1, 10) and axis 0 is to be mapped
arg 3 has shape (11,) and axis 0 is to be mapped
arg 4 has shape (11,) and axis 0 is to be mapped
so
arg 0 has an axis to be mapped of size 10
arg 2 has an axis to be mapped of size 1
args 3, 4 have axes to be mapped of size 11

This is my first time using JAX; my Google searches haven't helped me solve this, and the documentation isn't very clear to me. I'd really appreciate it if someone could help.

The problem is exactly what the error message says: to vmap over multiple arrays, the mapped axis in each array must have the same size. In your arrays it doesn't: you passed in_axes=(0, None, 0, 0, 0) for the arguments W1, W2, b, x, y, but W1.shape[0] = 10, b.shape[0] = 1, x.shape[0] = 11, and y.shape[0] = 11.

Because they are not equal, you get this error. To avoid it, you should only vmap over array axes of the same length.
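
To see the rule in isolation, here is a tiny assumed example (not from your code) where every mapped axis has the same length, so vmap succeeds:

xs = jnp.arange(3.0)        # shape (3,), mapped over axis 0
ys = jnp.arange(3.0) + 1.0  # shape (3,), mapped over axis 0
c = 2.0                     # not mapped (its in_axes entry is None)

# All mapped axes have length 3, so this works and returns an array of shape (3,).
# If ys had length 2 instead, you would get the same "inconsistent sizes" ValueError.
out = vmap(lambda a, b, s: a * s + b, in_axes=(0, 0, None))(xs, ys, c)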

For example, if you want to compute the gradient of W2 for each pair of W1, W2 entries, it might look like this (note the updated predict function and the updated in_axes):

import jax.numpy as jnp
from jax import random, nn, grad, vmap

key = random.PRNGKey(0)

x = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=jnp.float32)
y = jnp.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0], dtype=jnp.float32)

W1 = jnp.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=jnp.float32)
W2 = random.uniform(key, shape=(10,), minval=1, maxval=2, dtype=jnp.float32)
b = jnp.linspace(0, -9, 10, dtype=jnp.float32)
b = jnp.reshape(b, (1,10))

def predict(W1, W2, b, x):
  # Since you're vmapping over W1 and W2, your function needs to
  # handle scalar values, so we cast to 1D if necessary.
  W1 = jnp.atleast_1d(W1)
  W2 = jnp.atleast_1d(W2)

  f1 = jnp.einsum('i,j->ji', W1, x)+b
  f1 = nn.relu(f1)
  f2 = jnp.einsum('i,ji->j', W2, f1)
  return f2

def loss(W1, W2, b, x, y):
  preds = predict(W1, W2, b, x)
  return jnp.mean(jnp.square(y-preds))

perex_grads = vmap(grad(loss, argnums=1), in_axes=(0, 0, None, None, None))
pers_grads = perex_grads(W1, W2, b, x, y)
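
If what you actually want is a per-sample gradient of W2 (one gradient per (x, y) pair, as in your title), one possible sketch is to map only over x and y, which both have length 11, and leave the parameters unmapped. Here predict_single and loss_single are just illustrative helper names, not part of your original code:

def predict_single(W1, W2, b, x):
  # Inside vmap, x is a single scalar sample, so promote it to a length-1 vector.
  x = jnp.atleast_1d(x)
  f1 = nn.relu(jnp.einsum('i,j->ji', W1, x) + b)
  return jnp.einsum('i,ji->j', W2, f1)

def loss_single(W1, W2, b, x, y):
  pred = predict_single(W1, W2, b, x)
  return jnp.mean(jnp.square(y - pred))

# Map over x and y only; W1, W2, b are shared across samples.
# Result: one gradient of W2 per sample, shape (11, 10).
per_sample_grads = vmap(grad(loss_single, argnums=1),
                        in_axes=(None, None, None, 0, 0))(W1, W2, b, x, y)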