了解图卷积的 Einsum 用法

Understanding an Einsum usage for graph convolution

我在这里阅读时空图卷积运算的代码: https://github.com/yysijie/st-gcn/blob/master/net/utils/tgcn.py 而且我在理解 einsum 操作发生了什么时遇到了一些困难。特别是

对于 x 形状为 (N, kernel_size, kc//kernel_size, t, v) 的张量,其中 kernel_size 通常是 3,假设 kc=64*kernel_sizet 是帧数,比如 64,v 是顶点数,比如 25。N 是批量大小。

现在对于形状为 (3, 25, 25) 的张量 A,其中每个维度都是图顶点上的过滤操作,einsum 计算为:

x = torch.einsum('nkctv,kvw->nctw', (x, A))

我不确定如何解释这个表达式。我认为它的意思是,对于每个批次元素,对于 64 个中的每个通道 c_i,将通过该通道的 (64, 25) 特征映射的矩阵乘法获得的三个矩阵中的每一个求和,其值为A[i]。我这个正确吗?这个表达式有点冗长,在符号方面似乎有点奇怪 kc 作为一个变量名的用法,但是 k 分解为内核大小和 c 作为 einsum 表达式中的通道数 (192//3 = 64)。任何见解表示赞赏。

Y = torch.einsum('nkctv,kvw->nctw', (x, A)) means:

einsum interpretation on graph

为了更好地理解,我将左侧的 x 替换为 Y

有助于您仔细查看符号:

  • nkctv 左侧
  • kvw右侧
  • nctw 是结果

结果中缺少的是:

  • k
  • v

这些元素被加在一起成为一个值并被压缩,留下最终的形状。

沿线的东西(扩展形状(添加 1s)被广播和每个元素的总和):

  • 左:(n, k, c, t, v, 1)
  • 右:(1, k, 1, 1, v, w)

现在开始(l,r 代表左和右):

  • torch.mul(l, r)
  • torch.sum(l, r, dim=(1, 4))
  • 挤压任何奇异维度

这很难得到,因此爱因斯坦的总结有助于思考结果形状相互“混合”,至少对我而言。