torch.einsum 如何执行这个 4D 张量乘法?

How does torch.einsum perform this 4D tensor multiplication?

我遇到了使用 torch.einsum 计算张量乘法的代码。 I am able to understand the workings for lower order tensors,但是,不适用于如下所示的 4D 张量:

import torch

a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))

c = torch.einsum('nxhd,nyhd->nhxy', [a,b])

print(c.size())

# output: torch.Size([3, 2, 5, 4])

我需要以下方面的帮助:

  1. 这里执行的操作是什么(解释矩阵是如何 multiplied/transposed 等)?
  2. 在这种情况下 torch.einsum 真的有用吗?

(如果您只想分解 einsum 中涉及的步骤,请跳至 tl;dr 部分)

我将尝试解释 einsum 如何针对此示例逐步工作,但我将使用 numpy.einsumdocumentation 而不是使用 torch.einsum ),它的作用完全相同,但总的来说,我只是更喜欢它 table。尽管如此,同样的步骤也会发生在 torch 上。

让我们用 NumPy 重写上面的代码 -

import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape

#(3, 2, 5, 4)

循序渐进np.einsum

Einsum 由 3 个步骤组成:multiplysumtranspose

让我们看看我们的尺寸。我们有一个 (3, 5, 2, 10) 和一个 (3, 4, 2, 10) 我们需要根据 'nxhd,nyhd->nhxy'

带到 (3, 2, 5, 4)

1。乘

让我们不用担心 n,x,y,h,d 轴的顺序,如果要保留它们或删除(减少)它们,只需担心这个事实。把它们写成 table 看看我们如何安排我们的维度 -

        ## Multiply ##
       n   x   y   h   d
      --------------------
a  ->  3   5       2   10
b  ->  3       4   2   10
c1 ->  3   5   4   2   10

要使 xy 轴之间的广播乘法得到 (x, y),我们必须在正确的位置添加一个新轴然后相乘。

a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)

c1 = a1*b1
c1.shape

#(3, 5, 4, 2, 10)  #<-- (n, x, y, h, d)

2。总和/减少

接下来,我们要减少最后一个轴 10。这将得到尺寸 (n,x,y,h)

          ## Reduce ##
        n   x   y   h   d
       --------------------
c1  ->  3   5   4   2   10
c2  ->  3   5   4   2

这很简单。让我们在 axis=-1

上做 np.sum
c2 = np.sum(c1, axis=-1)
c2.shape

#(3,5,4,2)  #<-- (n, x, y, h)

3。转置

最后一步是使用转置重新排列轴。为此,我们可以使用 np.transposenp.transpose(0,3,1,2) 基本上是在第 0 轴之后引入第 3 轴,并推动第 1 和第 2 轴。所以,(n,x,y,h) 变成 (n,h,x,y)

c3 = c2.transpose(0,3,1,2)
c3.shape

#(3,2,5,4)  #<-- (n, h, x, y)

4。最终检查

让我们做最后的检查,看看 c3 是否与从 np.einsum -

生成的 c 相同
np.allclose(c,c3)

#True

TL;DR.

因此,我们将 'nxhd , nyhd -> nhxy' 实现为 -

input     -> nxhd, nyhd
multiply  -> nxyhd      #broadcasting
sum       -> nxyh       #reduce
transpose -> nhxy

优势

np.einsum 相对于采取的多个步骤的优势在于,您可以选择执行计算所需的“路径”并使用相同的函数执行多个操作。这可以通过 optimize 参数来完成,这将优化 einsum 表达式的收缩顺序。

这些操作的非详尽列表,可以通过 einsum 计算,如下所示,并附有示例:

  • 数组的踪迹,numpy.trace
  • Return 一条对角线,numpy.diag.
  • 阵列轴求和,numpy.sum
  • 换位和排列,numpy.transpose
  • 矩阵乘法和点积,numpy.matmul numpy.dot
  • 向量内积和外积,numpy.inner numpy.outer.
  • 广播、逐元素和标量乘法,numpy.multiply
  • 张量收缩,numpy.tensordot
  • 链式数组运算,计算顺序低效,numpy.einsum_path.

基准

%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

说明np.einsum运算速度比单个步骤快