PyTorch:如何使用“torch.einsum()”查找嵌套张量与另一个张量的点积之间的踪迹

PyTorch: How to use `torch.einsum()` to find the trace between the dot product of a nested tensor and another tensor

假设我有一个嵌套张量 A:

import torch.nn as nn
X = np.array([[1, 3, 2], [2, 3, 5], [1, 2, 3]])
X = torch.DoubleTensor(X)

rows = X.shape[0]
cols = X.shape[1]

A = torch.matmul(X.view(rows, cols, 1),
                 X.view(rows, 1, cols))

A

输出:

tensor([[[ 1.,  3.,  2.],
         [ 3.,  9.,  6.],
         [ 2.,  6.,  4.]],

        [[ 4.,  6., 10.],
         [ 6.,  9., 15.],
         [10., 15., 25.]],

        [[ 1.,  2.,  3.],
         [ 2.,  4.,  6.],
         [ 3.,  6.,  9.]]], dtype=torch.float64)

我还有另一个张量 B:

B = torch.DoubleTensor([[11., 21, 31], [31, 51, 31], [41, 51, 21]])
B

输出:

tensor([[11., 21., 31.],
        [31., 51., 31.],
        [41., 51., 21.]])

如何使用 torch.einsum() 求出 A 中每个嵌套张量与张量 B 的点积之间的迹值。例如。 A:

中第一个嵌套张量之间点积的迹值
[[ 1.,  3.,  2.],
 [ 3.,  9.,  6.],
 [ 2.,  6.,  4.]]

和乙:

tensor([[11., 21., 31.],
        [31., 51., 31.],
        [41., 51., 21.]])

与 A 中的其他 2 个嵌套张量类似。

我的结果张量将是一个只有 3 个迹值的张量。有没有一种方法可以做到这一点而无需遍历 A 中的每个嵌套张量(比如 for 循环)?

Ps:

我知道找到 2 个张量的点积之间的迹值的代码是:

torch.einsum('ij,ji->', X, Y).item()

如果您知道如何使用 numpy.einsum() 执行此操作,请也告诉我。我可能只需要稍微调整 numpy.einsum() 即可使其适用于 PyTorch 张量。

很简单,需要加上A的'batch dimension':

torch.einsum('bij,ji->b', A, B)

输出为

tensor([1346., 3290., 1216.])