Understanding PyTorch einsum
I'm familiar with how einsum works in NumPy. A similar functionality is also offered by PyTorch: torch.einsum(). What are the similarities and differences, either in terms of functionality or performance? The information available in the PyTorch documentation is rather scanty and doesn't provide any insights about this.

Since the description of einsum in the torch documentation is skimpy, I decided to write this post to document, compare, and contrast how torch.einsum() behaves when compared to numpy.einsum().
Differences:
NumPy allows both lowercase and uppercase letters [a-zA-Z] in the "subscript string", whereas PyTorch allows only lowercase letters [a-z].

NumPy accepts nd-arrays, plain Python lists (or tuples), lists of lists (or tuples of tuples, lists of tuples, tuples of lists), and even PyTorch tensors as operands (i.e. inputs). This is because the operands only need to be array_like and not strictly NumPy nd-arrays. Conversely, PyTorch expects the operands (i.e. inputs) to be strictly PyTorch tensors; it will throw a TypeError if you pass plain Python lists/tuples (or a combination thereof) or NumPy nd-arrays.

In addition to nd-arrays, NumPy also supports a number of keyword arguments (e.g. optimize), whereas PyTorch doesn't offer such flexibility yet. A short sketch of these differences follows below.
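A minimal sketch of the differences above (assuming a reasonably recent NumPy and the PyTorch version this post was written against; newer PyTorch releases have relaxed some of these restrictions):

import numpy as np
import torch

lst = [[1, 2], [3, 4]]     # a plain Python list of lists (array_like)

# NumPy: array_like operands, mixed-case subscripts, and the optimize kwarg all work
np.einsum('iJ, Jk -> ik', lst, lst)
np.einsum('ij, jk -> ik', lst, lst, optimize=True)

# PyTorch: operands must be torch.Tensors; plain lists raise a TypeError
try:
    torch.einsum('ij, jk -> ik', lst, lst)
except TypeError as err:
    print(err)

# ... and (at the time of writing) only lowercase subscripts [a-z] were accepted
t = torch.tensor(lst)
torch.einsum('ij, jk -> ik', t, t)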
Below are some examples, implemented in both PyTorch and NumPy:
# input tensors to work with
In [16]: vec
Out[16]: tensor([0, 1, 2, 3])
In [17]: aten
Out[17]:
tensor([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]])
In [18]: bten
Out[18]:
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]])
1) Matrix multiplication
PyTorch: torch.matmul(aten, bten); aten.mm(bten)
NumPy: np.einsum("ij, jk -> ik", arr1, arr2)
In [19]: torch.einsum('ij, jk -> ik', aten, bten)
Out[19]:
tensor([[130, 130, 130, 130],
[230, 230, 230, 230],
[330, 330, 330, 330],
[430, 430, 430, 430]])
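Continuing the same session, the einsum result can be sanity-checked against the native ops named above (my addition, not part of the original transcript):

res = torch.einsum('ij, jk -> ik', aten, bten)
assert torch.equal(res, torch.matmul(aten, bten))
assert torch.equal(res, aten.mm(bten))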
2) Extract elements along the main diagonal
PyTorch: torch.diag(aten)
NumPy: np.einsum("ii -> i", arr)
In [28]: torch.einsum('ii -> i', aten)
Out[28]: tensor([11, 22, 33, 44])
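The same contraction on the NumPy side, with aten converted to an ndarray (a quick check I added):

import numpy as np
np.einsum('ii -> i', aten.numpy())     # array([11, 22, 33, 44]); matches torch.diag(aten)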
3) Hadamard product (i.e. element-wise product of two tensors)
PyTorch: aten * bten
NumPy: np.einsum("ij, ij -> ij", arr1, arr2)
In [34]: torch.einsum('ij, ij -> ij', aten, bten)
Out[34]:
tensor([[ 11, 12, 13, 14],
[ 42, 44, 46, 48],
[ 93, 96, 99, 102],
[164, 168, 172, 176]])
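As expected, this agrees with the plain element-wise product (an added check):

assert torch.equal(torch.einsum('ij, ij -> ij', aten, bten), aten * bten)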
4) Element-wise squaring
PyTorch: aten ** 2
NumPy: np.einsum("ij, ij -> ij", arr, arr)
In [37]: torch.einsum('ij, ij -> ij', aten, aten)
Out[37]:
tensor([[ 121, 144, 169, 196],
[ 441, 484, 529, 576],
[ 961, 1024, 1089, 1156],
[1681, 1764, 1849, 1936]])
General: an element-wise nth power can be implemented by repeating the subscript string and the tensor n times. For example, computing the element-wise 4th power of a tensor can be done using:
# NumPy: np.einsum('ij, ij, ij, ij -> ij', arr, arr, arr, arr)
In [38]: torch.einsum('ij, ij, ij, ij -> ij', aten, aten, aten, aten)
Out[38]:
tensor([[ 14641, 20736, 28561, 38416],
[ 194481, 234256, 279841, 331776],
[ 923521, 1048576, 1185921, 1336336],
[2825761, 3111696, 3418801, 3748096]])
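To make the general rule above concrete, here is a small helper (hypothetical, not from the original answer) that builds the subscript string for an arbitrary n:

def einsum_nth_power(t, n):
    # element-wise n-th power of a 2D tensor by repeating the 'ij' subscript
    # and the operand n times, as described above
    subscripts = ', '.join(['ij'] * n) + ' -> ij'
    return torch.einsum(subscripts, *([t] * n))

assert torch.equal(einsum_nth_power(aten, 4), aten ** 4)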
5) Trace (i.e. sum of main-diagonal elements)
PyTorch: torch.trace(aten)
NumPy einsum: np.einsum("ii -> ", arr)
In [44]: torch.einsum('ii -> ', aten)
Out[44]: tensor(110)
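Again, a quick check against the native op (my addition):

assert torch.equal(torch.einsum('ii -> ', aten), torch.trace(aten))   # tensor(110)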
6) Matrix transpose
PyTorch: torch.transpose(aten, 1, 0)
NumPy einsum: np.einsum("ij -> ji", arr)
In [58]: torch.einsum('ij -> ji', aten)
Out[58]:
tensor([[11, 21, 31, 41],
[12, 22, 32, 42],
[13, 23, 33, 43],
[14, 24, 34, 44]])
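The transpose via einsum agrees with torch.transpose and with the 2D shorthand .t() (an added check):

res = torch.einsum('ij -> ji', aten)
assert torch.equal(res, torch.transpose(aten, 1, 0))
assert torch.equal(res, aten.t())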
7) Outer product (of vectors)
PyTorch: torch.ger(vec, vec)
NumPy einsum: np.einsum("i, j -> ij", vec, vec)
In [73]: torch.einsum('i, j -> ij', vec, vec)
Out[73]:
tensor([[0, 0, 0, 0],
[0, 1, 2, 3],
[0, 2, 4, 6],
[0, 3, 6, 9]])
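A quick check against torch.ger (note that newer PyTorch releases deprecate torch.ger in favour of torch.outer, if I remember the release notes correctly):

assert torch.equal(torch.einsum('i, j -> ij', vec, vec), torch.ger(vec, vec))
# on newer PyTorch: torch.outer(vec, vec)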
8) Inner product (of vectors)
PyTorch: torch.dot(vec1, vec2)
NumPy einsum: np.einsum("i, i -> ", vec1, vec2)
In [76]: torch.einsum('i, i -> ', vec, vec)
Out[76]: tensor(14)
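And the corresponding check for the dot product (added by me):

assert torch.equal(torch.einsum('i, i -> ', vec, vec), torch.dot(vec, vec))   # tensor(14)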
9) Sum along axis 0
PyTorch: torch.sum(aten, 0)
NumPy einsum: np.einsum("ij -> j", arr)
In [85]: torch.einsum('ij -> j', aten)
Out[85]: tensor([104, 108, 112, 116])
10) Sum along axis 1
PyTorch: torch.sum(aten, 1)
NumPy einsum: np.einsum("ij -> i", arr)
In [86]: torch.einsum('ij -> i', aten)
Out[86]: tensor([ 50, 90, 130, 170])
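Both axis sums above can be verified against torch.sum in one go (my addition):

assert torch.equal(torch.einsum('ij -> j', aten), torch.sum(aten, 0))   # column sums
assert torch.equal(torch.einsum('ij -> i', aten), torch.sum(aten, 1))   # row sums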
11) Batch matrix multiplication
PyTorch: torch.bmm(batch_tensor_1, batch_tensor_2)
NumPy: np.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
# input batch tensors to work with
In [13]: batch_tensor_1 = torch.arange(2 * 4 * 3).reshape(2, 4, 3)
In [14]: batch_tensor_2 = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
In [15]: torch.bmm(batch_tensor_1, batch_tensor_2)
Out[15]:
tensor([[[ 20, 23, 26, 29],
[ 56, 68, 80, 92],
[ 92, 113, 134, 155],
[ 128, 158, 188, 218]],
[[ 632, 671, 710, 749],
[ 776, 824, 872, 920],
[ 920, 977, 1034, 1091],
[1064, 1130, 1196, 1262]]])
# sanity check with the shapes
In [16]: torch.bmm(batch_tensor_1, batch_tensor_2).shape
Out[16]: torch.Size([2, 4, 4])
# batch matrix multiply using einsum
In [17]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2)
Out[17]:
tensor([[[ 20, 23, 26, 29],
[ 56, 68, 80, 92],
[ 92, 113, 134, 155],
[ 128, 158, 188, 218]],
[[ 632, 671, 710, 749],
[ 776, 824, 872, 920],
[ 920, 977, 1034, 1091],
[1064, 1130, 1196, 1262]]])
# sanity check with the shapes
In [18]: torch.einsum("bij, bjk -> bik", batch_tensor_1, batch_tensor_2).shape
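(The output of that last shape check is again torch.Size([2, 4, 4]).) The NumPy form named at the start of this example behaves identically on the converted arrays, e.g.:

import numpy as np
np_res = np.einsum("bij, bjk -> bik", batch_tensor_1.numpy(), batch_tensor_2.numpy())
np_res.shape    # (2, 4, 4), same values as torch.bmm / torch.einsum above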
12) Sum along axis 2
PyTorch: torch.sum(batch_ten, 2)
NumPy einsum: np.einsum("ijk -> ij", arr3D)
In [99]: torch.einsum("ijk -> ij", batch_ten)
Out[99]:
tensor([[ 50, 90, 130, 170],
[ 4, 8, 12, 16]])
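batch_ten itself is not defined anywhere in the transcript; stacking the two matrices from earlier reproduces the outputs shown here, which is my best guess at how it was built:

# assumption: batch_ten is aten and bten stacked along a new batch axis
batch_ten = torch.stack((aten, bten))            # shape: torch.Size([2, 4, 4])
torch.einsum("ijk -> ij", batch_ten)             # reproduces the tensor above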
13) Sum over all elements in an nD tensor
PyTorch: torch.sum(batch_ten)
NumPy einsum: np.einsum("ijk -> ", arr3D)
In [101]: torch.einsum("ijk -> ", batch_ten)
Out[101]: tensor(480)
14) Sum over multiple axes (i.e. marginalization)
PyTorch: torch.sum(arr, dim=(dim0, dim1, dim2, dim3, dim4, dim6, dim7))
NumPy: np.einsum("ijklmnop -> n", nDarr)
# 8D tensor
In [103]: nDten = torch.randn((3,5,4,6,8,2,7,9))
In [104]: nDten.shape
Out[104]: torch.Size([3, 5, 4, 6, 8, 2, 7, 9])
# marginalize out dimension 5 (i.e. "n" here)
In [111]: esum = torch.einsum("ijklmnop -> n", nDten)
In [112]: esum
Out[112]: tensor([ 98.6921, -206.0575])
# marginalize out axis 5 (i.e. sum over rest of the axes)
In [113]: tsum = torch.sum(nDten, dim=(0, 1, 2, 3, 4, 6, 7))
In [115]: torch.allclose(tsum, esum)
Out[115]: True
15) Double dot product / Frobenius inner product (same as: torch.sum(hadamard-product), cf. 3)
PyTorch: torch.sum(aten * bten)
NumPy: np.einsum("ij, ij -> ", arr1, arr2)
In [120]: torch.einsum("ij, ij -> ", aten, bten)
Out[120]: tensor(1300)
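Finally, the Frobenius inner product ties back to example 3: it is simply the Hadamard product summed over all entries (an added check):

res = torch.einsum("ij, ij -> ", aten, bten)
assert torch.equal(res, torch.sum(aten * bten))                            # tensor(1300)
assert torch.equal(res, torch.einsum("ij, ij -> ij", aten, bten).sum())    # cf. example 3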