torch.einsum 如何执行这个 4D 张量乘法?
How does torch.einsum perform this 4D tensor multiplication?
我遇到了使用 torch.einsum
计算张量乘法的代码。 I am able to understand the workings for lower order tensors,但是,不适用于如下所示的 4D 张量:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
我需要以下方面的帮助:
- 这里执行的操作是什么(解释矩阵是如何 multiplied/transposed 等)?
- 在这种情况下
torch.einsum
真的有用吗?
(如果您只想分解 einsum 中涉及的步骤,请跳至 tl;dr 部分)
我将尝试解释 einsum
如何针对此示例逐步工作,但我将使用 numpy.einsum
(documentation 而不是使用 torch.einsum
),它的作用完全相同,但总的来说,我只是更喜欢它 table。尽管如此,同样的步骤也会发生在 torch 上。
让我们用 NumPy 重写上面的代码 -
import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)
循序渐进np.einsum
Einsum 由 3 个步骤组成:multiply
、sum
和 transpose
让我们看看我们的尺寸。我们有一个 (3, 5, 2, 10)
和一个 (3, 4, 2, 10)
我们需要根据 'nxhd,nyhd->nhxy'
带到 (3, 2, 5, 4)
1。乘
让我们不用担心 n,x,y,h,d
轴的顺序,如果要保留它们或删除(减少)它们,只需担心这个事实。把它们写成 table 看看我们如何安排我们的维度 -
## Multiply ##
n x y h d
--------------------
a -> 3 5 2 10
b -> 3 4 2 10
c1 -> 3 5 4 2 10
要使 x
和 y
轴之间的广播乘法得到 (x, y)
,我们必须在正确的位置添加一个新轴然后相乘。
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)
2。总和/减少
接下来,我们要减少最后一个轴 10。这将得到尺寸 (n,x,y,h)
。
## Reduce ##
n x y h d
--------------------
c1 -> 3 5 4 2 10
c2 -> 3 5 4 2
这很简单。让我们在 axis=-1
上做 np.sum
c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2) #<-- (n, x, y, h)
3。转置
最后一步是使用转置重新排列轴。为此,我们可以使用 np.transpose
。 np.transpose(0,3,1,2)
基本上是在第 0 轴之后引入第 3 轴,并推动第 1 和第 2 轴。所以,(n,x,y,h)
变成 (n,h,x,y)
c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4) #<-- (n, h, x, y)
4。最终检查
让我们做最后的检查,看看 c3 是否与从 np.einsum
-
生成的 c 相同
np.allclose(c,c3)
#True
TL;DR.
因此,我们将 'nxhd , nyhd -> nhxy'
实现为 -
input -> nxhd, nyhd
multiply -> nxyhd #broadcasting
sum -> nxyh #reduce
transpose -> nhxy
优势
np.einsum
相对于采取的多个步骤的优势在于,您可以选择执行计算所需的“路径”并使用相同的函数执行多个操作。这可以通过 optimize
参数来完成,这将优化 einsum 表达式的收缩顺序。
这些操作的非详尽列表,可以通过 einsum
计算,如下所示,并附有示例:
- 数组的踪迹,
numpy.trace
。
- Return 一条对角线,
numpy.diag
.
- 阵列轴求和,
numpy.sum
。
- 换位和排列,
numpy.transpose
。
- 矩阵乘法和点积,
numpy.matmul
numpy.dot
。
- 向量内积和外积,
numpy.inner
numpy.outer
.
- 广播、逐元素和标量乘法,
numpy.multiply
。
- 张量收缩,
numpy.tensordot
。
- 链式数组运算,计算顺序低效,
numpy.einsum_path
.
基准
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
说明np.einsum
运算速度比单个步骤快
我遇到了使用 torch.einsum
计算张量乘法的代码。 I am able to understand the workings for lower order tensors,但是,不适用于如下所示的 4D 张量:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
我需要以下方面的帮助:
- 这里执行的操作是什么(解释矩阵是如何 multiplied/transposed 等)?
- 在这种情况下
torch.einsum
真的有用吗?
(如果您只想分解 einsum 中涉及的步骤,请跳至 tl;dr 部分)
我将尝试解释 einsum
如何针对此示例逐步工作,但我将使用 numpy.einsum
(documentation 而不是使用 torch.einsum
),它的作用完全相同,但总的来说,我只是更喜欢它 table。尽管如此,同样的步骤也会发生在 torch 上。
让我们用 NumPy 重写上面的代码 -
import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)
循序渐进np.einsum
Einsum 由 3 个步骤组成:multiply
、sum
和 transpose
让我们看看我们的尺寸。我们有一个 (3, 5, 2, 10)
和一个 (3, 4, 2, 10)
我们需要根据 'nxhd,nyhd->nhxy'
(3, 2, 5, 4)
1。乘
让我们不用担心 n,x,y,h,d
轴的顺序,如果要保留它们或删除(减少)它们,只需担心这个事实。把它们写成 table 看看我们如何安排我们的维度 -
## Multiply ##
n x y h d
--------------------
a -> 3 5 2 10
b -> 3 4 2 10
c1 -> 3 5 4 2 10
要使 x
和 y
轴之间的广播乘法得到 (x, y)
,我们必须在正确的位置添加一个新轴然后相乘。
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)
2。总和/减少
接下来,我们要减少最后一个轴 10。这将得到尺寸 (n,x,y,h)
。
## Reduce ##
n x y h d
--------------------
c1 -> 3 5 4 2 10
c2 -> 3 5 4 2
这很简单。让我们在 axis=-1
np.sum
c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2) #<-- (n, x, y, h)
3。转置
最后一步是使用转置重新排列轴。为此,我们可以使用 np.transpose
。 np.transpose(0,3,1,2)
基本上是在第 0 轴之后引入第 3 轴,并推动第 1 和第 2 轴。所以,(n,x,y,h)
变成 (n,h,x,y)
c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4) #<-- (n, h, x, y)
4。最终检查
让我们做最后的检查,看看 c3 是否与从 np.einsum
-
np.allclose(c,c3)
#True
TL;DR.
因此,我们将 'nxhd , nyhd -> nhxy'
实现为 -
input -> nxhd, nyhd
multiply -> nxyhd #broadcasting
sum -> nxyh #reduce
transpose -> nhxy
优势
np.einsum
相对于采取的多个步骤的优势在于,您可以选择执行计算所需的“路径”并使用相同的函数执行多个操作。这可以通过 optimize
参数来完成,这将优化 einsum 表达式的收缩顺序。
这些操作的非详尽列表,可以通过 einsum
计算,如下所示,并附有示例:
- 数组的踪迹,
numpy.trace
。 - Return 一条对角线,
numpy.diag
. - 阵列轴求和,
numpy.sum
。 - 换位和排列,
numpy.transpose
。 - 矩阵乘法和点积,
numpy.matmul
numpy.dot
。 - 向量内积和外积,
numpy.inner
numpy.outer
. - 广播、逐元素和标量乘法,
numpy.multiply
。 - 张量收缩,
numpy.tensordot
。 - 链式数组运算,计算顺序低效,
numpy.einsum_path
.
基准
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
说明np.einsum
运算速度比单个步骤快