使用 vmap (jax) 按元素对矩阵求和?
sum matrix elementwise using vmap (jax)?
我正在尝试了解 vmap 中的 in_axes 和 out_axes 选项。
例如,我想对两个矩阵求和并得到具有相同形状的输出。
X = np.arange(9).reshape(3,3)
Y = np.arange(0,-9,-1).reshape(3,3)
def sum2(x,y):
return x + y
vmap(sum2,in_axes=((0,1),(0,1)))(X,Y)
我想我分别为 X 和 Y 映射了轴 0 和轴 1。输出将具有与 X,Y 相同的形状。
但是我得到了错误,
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-403-103694166574> in <module>
3 def sum2(x,y):
4 return x + y
----> 5 vmap(sum2,in_axes=((0,1),(0,1)))(X,Y)
[... skipping hidden 2 frame]
~/anaconda3/lib/python3.8/site-packages/jax/api_util.py in flatten_axes(name, treedef, axis_tree, kws)
276 assert treedef_is_leaf(leaf)
277 axis_tree, _ = axis_tree
--> 278 raise ValueError(f"{name} specification must be a tree prefix of the "
279 f"corresponding value, got specification {axis_tree} "
280 f"for value tree {treedef}.") from None
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification ((0, 1), (0, 1)) for value tree PyTreeDef((*, *)).
首先,做元素求和最简单的方法是使用内置的二元运算广播,直接调用sum2(X, Y)
。
就是说,如果您想了解 vmap
:问题是 vmap
一次只能映射一个轴。如果要映射多个轴,可以嵌套多个vmap。我相信你的意图可以这样表达:
from jax import vmap
import jax.numpy as np
X = np.arange(9).reshape(3,3)
Y = np.arange(0,-9,-1).reshape(3,3)
def sum2(x,y):
assert x.ndim == y.ndim == 0
return x + y
vmap(vmap(sum
vmap(sum2, in_axes=(0, 0), out_axes=0),
in_axes=(1, 1), out_axes=1
)(X,Y)
注意:我添加了关于维数的断言,以证明正在对标量值调用映射函数。
另外,请注意当映射轴匹配时,例如in_axes=(0, 0)
可以等效地写成 in_axes=0
,但我将它保留为元组,因为它更接近您尝试的语法。
事实上,用嵌套 vmap
进行相同计算的一种更简洁的方法是使用默认参数:vmap(vmap(sum2))(X, Y)
将进行相同的元素求和。
我正在尝试了解 vmap 中的 in_axes 和 out_axes 选项。 例如,我想对两个矩阵求和并得到具有相同形状的输出。
X = np.arange(9).reshape(3,3)
Y = np.arange(0,-9,-1).reshape(3,3)
def sum2(x,y):
return x + y
vmap(sum2,in_axes=((0,1),(0,1)))(X,Y)
我想我分别为 X 和 Y 映射了轴 0 和轴 1。输出将具有与 X,Y 相同的形状。 但是我得到了错误,
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-403-103694166574> in <module>
3 def sum2(x,y):
4 return x + y
----> 5 vmap(sum2,in_axes=((0,1),(0,1)))(X,Y)
[... skipping hidden 2 frame]
~/anaconda3/lib/python3.8/site-packages/jax/api_util.py in flatten_axes(name, treedef, axis_tree, kws)
276 assert treedef_is_leaf(leaf)
277 axis_tree, _ = axis_tree
--> 278 raise ValueError(f"{name} specification must be a tree prefix of the "
279 f"corresponding value, got specification {axis_tree} "
280 f"for value tree {treedef}.") from None
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification ((0, 1), (0, 1)) for value tree PyTreeDef((*, *)).
首先,做元素求和最简单的方法是使用内置的二元运算广播,直接调用sum2(X, Y)
。
就是说,如果您想了解 vmap
:问题是 vmap
一次只能映射一个轴。如果要映射多个轴,可以嵌套多个vmap。我相信你的意图可以这样表达:
from jax import vmap
import jax.numpy as np
X = np.arange(9).reshape(3,3)
Y = np.arange(0,-9,-1).reshape(3,3)
def sum2(x,y):
assert x.ndim == y.ndim == 0
return x + y
vmap(vmap(sum
vmap(sum2, in_axes=(0, 0), out_axes=0),
in_axes=(1, 1), out_axes=1
)(X,Y)
注意:我添加了关于维数的断言,以证明正在对标量值调用映射函数。
另外,请注意当映射轴匹配时,例如in_axes=(0, 0)
可以等效地写成 in_axes=0
,但我将它保留为元组,因为它更接近您尝试的语法。
事实上,用嵌套 vmap
进行相同计算的一种更简洁的方法是使用默认参数:vmap(vmap(sum2))(X, Y)
将进行相同的元素求和。