使用 numpy 的批量张量乘法

Batched tensor multiplication with numpy

我正在尝试执行以下矩阵和张量乘法,但已批处理。

我有一个 x 向量列表:

x = np.array([[2.0, 2.0], [3.0, 3.0], [4.0, 4.0], [5.0, 5.0]])

以及以下矩阵和张量:

R = np.array(
    [
        [1.0, 1.0],
        [0.0, 1.0],
    ]
)
T = np.array(
    [
        [
            [2.0, 0.0],
            [0.0, 0.0],
        ],
        [
            [0.0, 0.0],
            [0.0, 2.0],
        ]
    ]
)

批量矩阵乘法相对简单:

x.dot(R.T)

但是我正在为第二部分而苦苦挣扎。

我尝试使用 tensordot 但到目前为止没有成功。我错过了什么?

您可以或多或少地直接将您的公式转换为 einsum:

>>> np.einsum('ijk,lj,lk->li', T, x, x)
array([[ 8.,  8.],
       [18., 18.],
       [32., 32.],
       [50., 50.]])

仅使用 @:

>>> ((x[:, None, None, :]@T).squeeze()@x[..., None]).squeeze()
array([[ 8.,  8.],
       [18., 18.],
       [32., 32.],
       [50., 50.]])

或混合型:

>>> np.einsum('ijl,lj->li', T@x.T, x)
array([[ 8.,  8.],
       [18., 18.],
       [32., 32.],
       [50., 50.]])

我们可以结合使用 tensor matrix-multiplicationnp.tensordot and einsum 基本上分两步完成 -

Tx = np.tensordot(T,x,axes=((1),(1)))
out = np.einsum('ikl,lk->li',Tx,x)

基准测试

根据 OP 的评论设置:

In [1]: import numpy as np

In [2]: x = np.random.rand(1000000,6)

In [3]: T = np.random.rand(6,6,6)

计时 -

# @Han Altae-Tran's soln
In [4]: %%timeit
   ...: W = np.matmul(T,x.T) 
   ...: ZT = np.sum(W*x.T[np.newaxis,:,:], axis=1).T
   ...: 
1 loops, best of 3: 496 ms per loop

# @Paul Panzer's soln-1
In [5]: %timeit np.einsum('ijk,lj,lk->li', T, x, x)
1 loops, best of 3: 831 ms per loop

# @Paul Panzer's soln-2
In [6]: %timeit ((x[:, None, None, :]@T).squeeze()@x[..., None]).squeeze()
1 loops, best of 3: 1.39 s per loop

# @Paul Panzer's soln-3
In [7]: %timeit np.einsum('ijl,lj->li', T@x.T, x)
1 loops, best of 3: 358 ms per loop

# From this post's soln
In [8]: %%timeit
   ...: Tx = np.tensordot(T,x,axes=((1),(1)))
   ...: out = np.einsum('ikl,lk->li',Tx,x)
   ...: 
1 loops, best of 3: 168 ms per loop

正如 Paul 所指出的,einsum 是完成任务的一种简单方法,但如果速度是一个问题,那么通常最好坚持使用典型的 numpy 函数。

这可以通过写出方程并将步骤转换为矩阵运算来实现。

X 为您要批处理的 m x d 数据矩阵,Z 为您想要的 m x d 结果。我们将到达 Z.T(转置),因为它更容易。

请注意,为了得出 R 贡献的等式,我们可以写成

然后我们可以将其作为一个 numpy 矩阵乘以 R.dot(X.T)

同样,观察 T 的贡献是

括号内是 TX.T 之间的批处理矩阵乘法。因此,如果我们将括号内的数量定义为

我们可以使用 W = np.matmul(T,X.T) 在 numpy 中得到它。继续我们的简化,我们看到 T 的贡献是

相当于np.sum(W*X.T[np.newaxis,:,:], axis=1)。将所有内容放在一起,我们最终得到

W = np.matmul(T,X.T) 
ZT = R.dot(X.T) + np.sum(W*X.T[np.newaxis,:,:], axis=1) 
Z = ZT.T

对于较大的批量大小,这比 d=2 时的 einsum 函数快大约 3-4 倍。如果我们要避免使用尽可能多的转置,它可能会更快一点。

由于缓存使用不是小张量序列的问题(就像大矩阵的一般点积一样),因此很容易用简单的循环来表述问题。

例子

import numba as nb
import numpy as np
import time

@nb.njit(fastmath=True,parallel=True)
def tensor_mult(T,x):
  res=np.empty((x.shape[0],T.shape[0]),dtype=T.dtype)
  for l in nb.prange(x.shape[0]):
    for i in range(T.shape[0]):
      sum=0.
      for j in range(T.shape[1]):
        for k in range(T.shape[2]):
          sum+=T[i,j,k]*x[l,j]*x[l,k]
      res[l,i]=sum
  return res

基准测试

x = np.random.rand(1000000,6)
T = np.random.rand(6,6,6)

#first call has a compilation overhead (about 0.6s)
res=tensor_mult(T,x)

t1=time.time()
for i in range(10):
  #@divakar
  #Tx = np.tensordot(T,x,axes=((1),(1)))
  #out = np.einsum('ikl,lk->li',Tx,x)

  res=tensor_mult(T,x)

print(time.time()-t1)

结果(4C/8T)

Divakars solution: 191ms
Simple loops: 62.4ms