仅保留对角线条目时执行大 dot/tensor 点积的最有效方法

Most efficient way to perform large dot/tensor dot products while only keeping diagonal entries

我正在尝试找出一种使用 numpy 以最省时的方式执行以下代数的方法:

给定一个形状为 (n, m, p) 的 3D matrix/tensor、A 和一个形状为 (n, p) 的 2D matrix/tensor、B,计算 C_ij = sum_over_k (A_ijk * B_ik),其中生成的矩阵 C 的维度为 (n, m)。

我已经尝试了两种方法来做到这一点。一种是循环遍历第一维,每次计算一个规则的点积。 另一种方法是用np.tensordot(A, B.T)计算出一个形状为(n, m, n)的结果,然后取1维和3维的对角线元素。两种方法如下所示。

第一种方法:

C = np.zeros((n,m))

for i in range(n):

  C[i] = np.dot(A[i], B[i])

第二种方法:

C = np.diagonal(np.tensordot(A, B.T, axes = 1), axis1=0, axis2=2).T

但是由于n是一个很大的数,第一种方法中n的循环比较耗时。第二种方法计算了太多不必要的条目来获得那个巨大的(n, m, n)矩阵,而且也花费了太多时间,我想知道是否有任何有效的方法可以做到这一点?

这是我的实现:

B = np.expand_dims(B, axis=1)
E = A * B
E = np.sum(E, axis=-1)

检查:

import numpy as np
n, m, p = 2, 2, 2
np.random.seed(0)
A = np.random.randint(1, 10, (n, m, p))
B = np.random.randint(1, 10, (n, p))

C = np.diagonal(np.tensordot(A, B.T, axes = 1), axis1=0, axis2=2).T

# from here is my implementation
B = np.expand_dims(B, axis=1)
E = A * B
E = np.sum(E, axis=-1)

print(np.array_equal(C, E))

True

使用 np.expand_dims() 添加新维度。 并使用广播相乘。最后,沿第三个维度求和。

感谢来自 user3483203

的验证码

定义 2 个数组:

In [168]: A = np.arange(2*3*4).reshape(2,3,4); B = np.arange(2*4).reshape(2,4)                               

您的迭代方法:

In [169]: [np.dot(a,b) for a,b in zip(A,B)]                                                                  
Out[169]: [array([14, 38, 62]), array([302, 390, 478])]

einsum 实际上是从你的 C_ij = sum_over_k (A_ijk * B_ik):

中写出来的
In [170]: np.einsum('ijk,ik->ij', A, B)                                                                      
Out[170]: 
array([[ 14,  38,  62],
       [302, 390, 478]])
添加了

@matmul 以执行批量 dot 产品;这里的 i 维度是批次一。由于A的最后一个和B的第2个到最后一个用于dot求和,我们不得不暂时将B扩展为(2,4,1)

In [171]: A@B[...,None]                                                                                      
Out[171]: 
array([[[ 14],
        [ 38],
        [ 62]],

       [[302],
        [390],
        [478]]])
In [172]: (A@B[...,None])[...,0]                                                                             
Out[172]: 
array([[ 14,  38,  62],
       [302, 390, 478]])

通常 matmul 是最快的,因为它像代码一样将任务传递给 BLAS。