torch / np einsum 在内部究竟是如何工作的
How exactly does torch / np einsum work internally
这是关于 GPU 中 torch.einsum
内部工作的查询。我知道如何使用 einsum
。它是执行所有可能的矩阵乘法,并只挑出相关的乘法,还是只执行所需的计算?
例如,考虑两个张量 a
和 b
,形状为 (N,P)
,我希望找到每个对应张量 ni
的点积,形状 (1,P)
。
使用einsum,代码为:
torch.einsum('ij,ij->i',a,b)
在不使用 einsum 的情况下,另一种获取输出的方法是:
torch.diag(a @ b.t())
现在,第二个代码应该比第一个代码执行更多的计算(例如,如果 N
= 2000
,它执行 2000
倍的计算)。但是,当我尝试对这两个操作进行计时时,它们完成所需的时间大致相同,这就引出了一个问题。 einsum
是否执行所有组合(如第二个代码),并挑选出相关值?
要测试的示例代码:
import time
import torch
for i in range(100):
a = torch.rand(50000, 256).cuda()
b = torch.rand(50000, 256).cuda()
t1 = time.time()
val = torch.diag(a @ b.t())
t2 = time.time()
val2 = torch.einsum('ij,ij->i',a,b)
t3 = time.time()
print(t2-t1,t3-t2, torch.allclose(val,val2))
这可能与GPU可以并行计算a @ b.t()
有关。这意味着 GPU 实际上不必等待每个 row-column 乘法计算完成才能计算下一个乘法。
如果你检查 CPU 然后你会发现 torch.diag(a @ b.t())
对于大的 a
和 b
.
比 torch.einsum('ij,ij->i',a,b)
慢得多
我不能代表 torch
,但几年前曾在一些细节上与 np.einsum
合作过。然后它根据索引字符串构造一个自定义迭代器,只进行必要的计算。从那时起,它以各种方式进行了重新设计,并且在可能的情况下显然将问题转换为 @
,从而利用了 BLAS(等)库调用。
In [147]: a = np.arange(12).reshape(3,4)
In [148]: b = a
In [149]: np.einsum('ij,ij->i', a,b)
Out[149]: array([ 14, 126, 366])
我不能确定在这种情况下使用什么方法。使用 'j' 求和,也可以用:
In [150]: (a*b).sum(axis=1)
Out[150]: array([ 14, 126, 366])
如您所见,最简单的 dot
创建一个更大的数组,我们可以从中拉出对角线:
In [151]: (a@b.T).shape
Out[151]: (3, 3)
但这不是使用 @
的正确方法。 @
通过提供高效的 'batch' 处理扩展了 np.dot
。所以 i
维度是批次的,j
是 dot
的。
In [152]: a[:,None,:]@b[:,:,None]
Out[152]:
array([[[ 14]],
[[126]],
[[366]]])
In [156]: (a[:,None,:]@b[:,:,None])[:,0,0]
Out[156]: array([ 14, 126, 366])
换句话说,它使用 (3,1,4) 和 (3,4,1) 生成 (3,1,1),在共享大小 4 维度上进行乘积求和。
部分采样时间:
In [162]: timeit np.einsum('ij,ij->i', a,b)
7.07 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [163]: timeit (a*b).sum(axis=1)
9.89 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [164]: timeit np.diag(a@b.T)
10.6 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [165]: timeit (a[:,None,:]@b[:,:,None])[:,0,0]
5.18 µs ± 197 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
这是关于 GPU 中 torch.einsum
内部工作的查询。我知道如何使用 einsum
。它是执行所有可能的矩阵乘法,并只挑出相关的乘法,还是只执行所需的计算?
例如,考虑两个张量 a
和 b
,形状为 (N,P)
,我希望找到每个对应张量 ni
的点积,形状 (1,P)
。
使用einsum,代码为:
torch.einsum('ij,ij->i',a,b)
在不使用 einsum 的情况下,另一种获取输出的方法是:
torch.diag(a @ b.t())
现在,第二个代码应该比第一个代码执行更多的计算(例如,如果 N
= 2000
,它执行 2000
倍的计算)。但是,当我尝试对这两个操作进行计时时,它们完成所需的时间大致相同,这就引出了一个问题。 einsum
是否执行所有组合(如第二个代码),并挑选出相关值?
要测试的示例代码:
import time
import torch
for i in range(100):
a = torch.rand(50000, 256).cuda()
b = torch.rand(50000, 256).cuda()
t1 = time.time()
val = torch.diag(a @ b.t())
t2 = time.time()
val2 = torch.einsum('ij,ij->i',a,b)
t3 = time.time()
print(t2-t1,t3-t2, torch.allclose(val,val2))
这可能与GPU可以并行计算a @ b.t()
有关。这意味着 GPU 实际上不必等待每个 row-column 乘法计算完成才能计算下一个乘法。
如果你检查 CPU 然后你会发现 torch.diag(a @ b.t())
对于大的 a
和 b
.
torch.einsum('ij,ij->i',a,b)
慢得多
我不能代表 torch
,但几年前曾在一些细节上与 np.einsum
合作过。然后它根据索引字符串构造一个自定义迭代器,只进行必要的计算。从那时起,它以各种方式进行了重新设计,并且在可能的情况下显然将问题转换为 @
,从而利用了 BLAS(等)库调用。
In [147]: a = np.arange(12).reshape(3,4)
In [148]: b = a
In [149]: np.einsum('ij,ij->i', a,b)
Out[149]: array([ 14, 126, 366])
我不能确定在这种情况下使用什么方法。使用 'j' 求和,也可以用:
In [150]: (a*b).sum(axis=1)
Out[150]: array([ 14, 126, 366])
如您所见,最简单的 dot
创建一个更大的数组,我们可以从中拉出对角线:
In [151]: (a@b.T).shape
Out[151]: (3, 3)
但这不是使用 @
的正确方法。 @
通过提供高效的 'batch' 处理扩展了 np.dot
。所以 i
维度是批次的,j
是 dot
的。
In [152]: a[:,None,:]@b[:,:,None]
Out[152]:
array([[[ 14]],
[[126]],
[[366]]])
In [156]: (a[:,None,:]@b[:,:,None])[:,0,0]
Out[156]: array([ 14, 126, 366])
换句话说,它使用 (3,1,4) 和 (3,4,1) 生成 (3,1,1),在共享大小 4 维度上进行乘积求和。
部分采样时间:
In [162]: timeit np.einsum('ij,ij->i', a,b)
7.07 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [163]: timeit (a*b).sum(axis=1)
9.89 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [164]: timeit np.diag(a@b.T)
10.6 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [165]: timeit (a[:,None,:]@b[:,:,None])[:,0,0]
5.18 µs ± 197 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)