in_axes JAX 的 vmap 中的关键字
in_axes keyword in JAX's vmap
我正在尝试使用 vmap
了解 JAX 的自动矢量化功能,并根据 JAX 的文档实现了一个最小的工作示例。
我不明白 in_axes
是如何正确使用的。在下面的示例中,我可以设置 in_axes=(None, 0)
或 in_axes=(None, 1)
来获得相同的结果。为什么会这样?
为什么我必须使用 in_axes=(None, 0)
而不是 in_axes=(0, )
之类的东西?
import jax.numpy as jnp
from jax import vmap
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = input_vec
for W, b in params:
outputs = jnp.dot(W, activations) + b
activations = jnp.tanh(outputs)
return outputs
if __name__ == "__main__":
# Parameters
dims = [2, 3, 5]
input_dims = dims[0]
batch_size = 2
# Weights
params = list()
for dims_in, dims_out in zip(dims, dims[1:]):
params.append((jnp.ones((dims_out, dims_in)), jnp.ones((dims_out,))))
# Input data
input_batch = jnp.ones((batch_size, input_dims))
# With vmap
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
print(predictions)
in_axes=(None, 0)
表示第一个参数(此处为 params
)将不会被映射,而第二个参数(此处为 input_vec
)将沿轴 0 进行映射。
In the example below I can set in_axes=(None, 0)
or in_axes=(None, 1)
leading to the same results. Why is that the case?
这是因为 input_vec
是一个 2x2 矩阵,所以无论您是沿轴 0 还是轴 1 映射,输入向量都是长度为 2 的向量。在更一般的情况下,这两个规范是不等价的,您可以通过 (1) 使 batch_size
与 input_dims[0]
不同,或 (2) 用非常量值填充数组来看出这一点。
why do I have to use in_axes=(None, 0)
and not something like in_axes=(0, )
?
如果为具有两个参数的函数设置 in_axes=(0, )
,则会出现错误,因为 in_axes
元组的长度必须与传递给函数的参数数量匹配。也就是说,可以将标量 in_axes=0
作为 shorthand 传递给 in_axes=(0, 0)
,但是对于您的函数,这会导致形状错误,因为 [=] 中数组的前导维度11=] 与 input_vec
.
的前导维度不匹配
我正在尝试使用 vmap
了解 JAX 的自动矢量化功能,并根据 JAX 的文档实现了一个最小的工作示例。
我不明白 in_axes
是如何正确使用的。在下面的示例中,我可以设置 in_axes=(None, 0)
或 in_axes=(None, 1)
来获得相同的结果。为什么会这样?
为什么我必须使用 in_axes=(None, 0)
而不是 in_axes=(0, )
之类的东西?
import jax.numpy as jnp
from jax import vmap
def predict(params, input_vec):
assert input_vec.ndim == 1
activations = input_vec
for W, b in params:
outputs = jnp.dot(W, activations) + b
activations = jnp.tanh(outputs)
return outputs
if __name__ == "__main__":
# Parameters
dims = [2, 3, 5]
input_dims = dims[0]
batch_size = 2
# Weights
params = list()
for dims_in, dims_out in zip(dims, dims[1:]):
params.append((jnp.ones((dims_out, dims_in)), jnp.ones((dims_out,))))
# Input data
input_batch = jnp.ones((batch_size, input_dims))
# With vmap
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
print(predictions)
in_axes=(None, 0)
表示第一个参数(此处为 params
)将不会被映射,而第二个参数(此处为 input_vec
)将沿轴 0 进行映射。
In the example below I can set
in_axes=(None, 0)
orin_axes=(None, 1)
leading to the same results. Why is that the case?
这是因为 input_vec
是一个 2x2 矩阵,所以无论您是沿轴 0 还是轴 1 映射,输入向量都是长度为 2 的向量。在更一般的情况下,这两个规范是不等价的,您可以通过 (1) 使 batch_size
与 input_dims[0]
不同,或 (2) 用非常量值填充数组来看出这一点。
why do I have to use
in_axes=(None, 0)
and not something likein_axes=(0, )
?
如果为具有两个参数的函数设置 in_axes=(0, )
,则会出现错误,因为 in_axes
元组的长度必须与传递给函数的参数数量匹配。也就是说,可以将标量 in_axes=0
作为 shorthand 传递给 in_axes=(0, 0)
,但是对于您的函数,这会导致形状错误,因为 [=] 中数组的前导维度11=] 与 input_vec
.