Pytorch矩阵乘法

Pytorch matrix multiplication

我正在为 pytorch 中的维度和矩阵乘法而苦苦挣扎。 我想乘以矩阵 A

tensor([[[104.7500, 111.3750, 138.2500, 144.8750],
         [104.2500, 110.8750, 137.7500, 144.3750]],

        [[356.8750, 363.5000, 390.3750, 397.0000],
         [356.3750, 363.0000, 389.8750, 396.5000]]])

矩阵 B

tensor([[[[  0.,   1.,   2.,   5.,   6.,   7.,  10.,  11.,  12.],
          [  2.,   3.,   4.,   7.,   8.,   9.,  12.,  13.,  14.],
          [ 10.,  11.,  12.,  15.,  16.,  17.,  20.,  21.,  22.],
          [ 12.,  13.,  14.,  17.,  18.,  19.,  22.,  23.,  24.]],

         [[ 25.,  26.,  27.,  30.,  31.,  32.,  35.,  36.,  37.],
          [ 27.,  28.,  29.,  32.,  33.,  34.,  37.,  38.,  39.],
          [ 35.,  36.,  37.,  40.,  41.,  42.,  45.,  46.,  47.],
          [ 37.,  38.,  39.,  42.,  43.,  44.,  47.,  48.,  49.]],

         [[ 50.,  51.,  52.,  55.,  56.,  57.,  60.,  61.,  62.],
          [ 52.,  53.,  54.,  57.,  58.,  59.,  62.,  63.,  64.],
          [ 60.,  61.,  62.,  65.,  66.,  67.,  70.,  71.,  72.],
          [ 62.,  63.,  64.,  67.,  68.,  69.,  72.,  73.,  74.]]],


        [[[ 75.,  76.,  77.,  80.,  81.,  82.,  85.,  86.,  87.],
          [ 77.,  78.,  79.,  82.,  83.,  84.,  87.,  88.,  89.],
          [ 85.,  86.,  87.,  90.,  91.,  92.,  95.,  96.,  97.],
          [ 87.,  88.,  89.,  92.,  93.,  94.,  97.,  98.,  99.]],

         [[100., 101., 102., 105., 106., 107., 110., 111., 112.],
          [102., 103., 104., 107., 108., 109., 112., 113., 114.],
          [110., 111., 112., 115., 116., 117., 120., 121., 122.],
          [112., 113., 114., 117., 118., 119., 122., 123., 124.]],

         [[125., 126., 127., 130., 131., 132., 135., 136., 137.],
          [127., 128., 129., 132., 133., 134., 137., 138., 139.],
          [135., 136., 137., 140., 141., 142., 145., 146., 147.],
          [137., 138., 139., 142., 143., 144., 147., 148., 149.]]]])

然而,使用简单的 @ 将它们相乘,并没有让我得到想要的结果。 我想要的是这样的东西:将 A 的前两行乘以 B 的前 3 个 4x9 子矩阵(比方说 B[:,:,0,:]),这样我就有两个结果,然后以同样的方式乘以A 的第三行和第四行与 B 的第二个 3 4x9 子矩阵,所以再次有两个结果,然后我想将每个乘法的第一个结果和每个乘法的第二个结果相加。 我知道我必须进行某种整形,但我发现它很混乱,你能帮我提供一个非常通用的解决方案吗?

这个例子会有所帮助:

a = torch.ones((4, 4)).long()
a = a.reshape(2, 2, 4)
b = torch.tensor(list(range(36*6)))
b = b.reshape(2, 3, 4, 9)

t1 = a[0] @ b[0, :]
t2 = a[1] @ b[1, :]
result = t1 + t2

accum = torch.zeros((b.shape[1], a.shape[1], b.shape[3]))
for i in range(a.shape[0]):
  accum = accum + (a[i] @ b[i, :])

万一有人想知道如何使用 torch.einsum 执行此操作,您只需考虑维度并使用下标显式操作:

>>> torch.einsum('ijk,ilkm->ljm', A, B)

执行的整体操作是 pseudo-code:

for i, j, k, l, m in IxJxKxLxM:
    out[l][j][m] += A[i][j][k]*B[i][l][k][m]