vmap 在 jax 中的列表
vmap over a list in jax
我尝试使用 jax 计算每个样本的梯度,对其进行处理,然后将它们置于正常形式以计算正常参数更新。
我的工作代码看起来像
differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)
# some code
gradients_summed_over_samples = []
for layer in gradients:
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
gradients_summed_over_samples.append((dw, db))
其中 gradients
的形式为 list(tuple(DeviceArray(...), DeviceArray(...)), ...)
.
现在我尝试将循环重写为 vmap(不确定最终是否带来加速)
def sum_samples(layer):
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
vmap(sum_samples)(gradients)
但是 sum_samples
只被调用一次,而不是为列表中的每个元素调用。
是列表有问题还是我理解有问题?
jax.vmap
只会映射到 jax 数组输入,而不是数组或元组列表的输入。此外,vmapped 函数不能就地修改输入;函数应该 return 一个值,这个 return 值将与其他 return 值叠加以构造输出
例如,您可以修改您定义的函数并像这样使用它:
import jax.numpy as np
from jax import random
def sum_samples(layer):
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
return np.array([dw, db])
key = random.PRNGKey(1701)
data = random.uniform(key, (10, 2, 20))
result = vmap(sum_samples)(data)
print(result.shape)
# (10, 2)
旁注:如果您使用这种方法,上面的 vmapped 函数可以更简洁地表示为:
def sum_samples(layer):
return layer.sum(1)
我尝试使用 jax 计算每个样本的梯度,对其进行处理,然后将它们置于正常形式以计算正常参数更新。 我的工作代码看起来像
differentiate_per_sample = jit(vmap(grad(loss), in_axes=(None, 0, 0)))
gradients = differentiate_per_sample(params, x, y)
# some code
gradients_summed_over_samples = []
for layer in gradients:
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
gradients_summed_over_samples.append((dw, db))
其中 gradients
的形式为 list(tuple(DeviceArray(...), DeviceArray(...)), ...)
.
现在我尝试将循环重写为 vmap(不确定最终是否带来加速)
def sum_samples(layer):
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
vmap(sum_samples)(gradients)
但是 sum_samples
只被调用一次,而不是为列表中的每个元素调用。
是列表有问题还是我理解有问题?
jax.vmap
只会映射到 jax 数组输入,而不是数组或元组列表的输入。此外,vmapped 函数不能就地修改输入;函数应该 return 一个值,这个 return 值将与其他 return 值叠加以构造输出
例如,您可以修改您定义的函数并像这样使用它:
import jax.numpy as np
from jax import random
def sum_samples(layer):
(dw, db) = layer
(dw, db) = (np.sum(dw, axis=0), np.sum(db, axis=0))
return np.array([dw, db])
key = random.PRNGKey(1701)
data = random.uniform(key, (10, 2, 20))
result = vmap(sum_samples)(data)
print(result.shape)
# (10, 2)
旁注:如果您使用这种方法,上面的 vmapped 函数可以更简洁地表示为:
def sum_samples(layer):
return layer.sum(1)