in_axes JAX 的 vmap 中的关键字

in_axes keyword in JAX's vmap

我正在尝试使用 vmap 了解 JAX 的自动矢量化功能,并根据 JAX 的文档实现了一个最小的工作示例。

我不明白 in_axes 是如何正确使用的。在下面的示例中,我可以设置 in_axes=(None, 0)in_axes=(None, 1) 来获得相同的结果。为什么会这样?

为什么我必须使用 in_axes=(None, 0) 而不是 in_axes=(0, ) 之类的东西?

import jax.numpy as jnp
from jax import vmap


def predict(params, input_vec):
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs


if __name__ == "__main__":

    # Parameters
    dims = [2, 3, 5]
    input_dims = dims[0]
    batch_size = 2

    # Weights
    params = list()
    for dims_in, dims_out in zip(dims, dims[1:]):
        params.append((jnp.ones((dims_out, dims_in)), jnp.ones((dims_out,))))

    # Input data
    input_batch = jnp.ones((batch_size, input_dims))

    # With vmap
    predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
    print(predictions)

in_axes=(None, 0) 表示第一个参数(此处为 params)将不会被映射,而第二个参数(此处为 input_vec)将沿轴 0 进行映射。

In the example below I can set in_axes=(None, 0) or in_axes=(None, 1) leading to the same results. Why is that the case?

这是因为 input_vec 是一个 2x2 矩阵,所以无论您是沿轴 0 还是轴 1 映射,输入向量都是长度为 2 的向量。在更一般的情况下,这两个规范是不等价的,您可以通过 (1) 使 batch_sizeinput_dims[0] 不同,或 (2) 用非常量值填充数组来看出这一点。

why do I have to use in_axes=(None, 0) and not something like in_axes=(0, )?

如果为具有两个参数的函数设置 in_axes=(0, ),则会出现错误,因为 in_axes 元组的长度必须与传递给函数的参数数量匹配。也就是说,可以将标量 in_axes=0 作为 shorthand 传递给 in_axes=(0, 0),但是对于您的函数,这会导致形状错误,因为 [=] 中数组的前导维度11=] 与 input_vec.

的前导维度不匹配