torch / np einsum 在内部究竟是如何工作的

How exactly does torch / np einsum work internally

这是关于 GPU 中 torch.einsum 内部工作的查询。我知道如何使用 einsum。它是执行所有可能的矩阵乘法,并只挑出相关的乘法,还是只执行所需的计算?

例如,考虑两个张量 ab,形状为 (N,P),我希望找到每个对应张量 ni 的点积,形状 (1,P)。 使用einsum,代码为:

torch.einsum('ij,ij->i',a,b)

在不使用 einsum 的情况下,另一种获取输出的方法是:

torch.diag(a @ b.t())

现在,第二个代码应该比第一个代码执行更多的计算(例如,如果 N = 2000,它执行 2000 倍的计算)。但是,当我尝试对这两个操作进行计时时,它们完成所需的时间大致相同,这就引出了一个问题。 einsum是否执行所有组合(如第二个代码),并挑选出相关值?

要测试的示例代码:

import time
import torch
for i in range(100):
  a = torch.rand(50000, 256).cuda()
  b = torch.rand(50000, 256).cuda()

  t1 = time.time()
  val = torch.diag(a @ b.t())
  t2 = time.time()
  val2 = torch.einsum('ij,ij->i',a,b)
  t3 = time.time()
  print(t2-t1,t3-t2, torch.allclose(val,val2))

这可能与GPU可以并行计算a @ b.t()有关。这意味着 GPU 实际上不必等待每个 row-column 乘法计算完成才能计算下一个乘法。 如果你检查 CPU 然后你会发现 torch.diag(a @ b.t()) 对于大的 ab.

torch.einsum('ij,ij->i',a,b) 慢得多

我不能代表 torch,但几年前曾在一些细节上与 np.einsum 合作过。然后它根据索引字符串构造一个自定义迭代器,只进行必要的计算。从那时起,它以各种方式进行了重新设计,并且在可能的情况下显然将问题转换为 @,从而利用了 BLAS(等)库调用。

In [147]: a = np.arange(12).reshape(3,4)
In [148]: b = a

In [149]: np.einsum('ij,ij->i', a,b)
Out[149]: array([ 14, 126, 366])

我不能确定在这种情况下使用什么方法。使用 'j' 求和,也可以用:

In [150]: (a*b).sum(axis=1)
Out[150]: array([ 14, 126, 366])

如您所见,最简单的 dot 创建一个更大的数组,我们可以从中拉出对角线:

In [151]: (a@b.T).shape
Out[151]: (3, 3)

但这不是使用 @ 的正确方法。 @ 通过提供高效的 'batch' 处理扩展了 np.dot。所以 i 维度是批次的,jdot 的。

In [152]: a[:,None,:]@b[:,:,None]
Out[152]: 
array([[[ 14]],

       [[126]],

       [[366]]])
In [156]: (a[:,None,:]@b[:,:,None])[:,0,0]
Out[156]: array([ 14, 126, 366])

换句话说,它使用 (3,1,4) 和 (3,4,1) 生成 (3,1,1),在共享大小 4 维度上进行乘积求和。

部分采样时间:

In [162]: timeit np.einsum('ij,ij->i', a,b)
7.07 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [163]: timeit (a*b).sum(axis=1)
9.89 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [164]: timeit np.diag(a@b.T)
10.6 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [165]: timeit (a[:,None,:]@b[:,:,None])[:,0,0]
5.18 µs ± 197 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)