PyTorch:PyTorch 中的 numpy.linalg.multi_dot() 等价物是什么

PyTorch: What is numpy.linalg.multi_dot() equivalent in PyTorch

我正在尝试在 PyTorch 中执行多个矩阵的矩阵乘法,想知道 PyTorch 中 numpy.linalg.multi_dot() 的等价物是什么?

如果没有,我可以在 PyTorch 中执行此操作的下一个最佳方法是什么(在速度和内存方面)?

代码:

import numpy as np
import torch

A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)

results = np.linalg.multi_dot(A, B, C)

A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)

# What is the PyTorch equivalent of np.linalg.multi_dot()?

非常感谢!

~~看起来可以将张量发送到 multi_dot~~

看起来 numpy 实现将所有内容都转换为 numpy 数组。如果您的张量在 cpu 上并且分离,这应该可以工作。否则,转换为 numpy 将失败。

所以总的来说 - 可能没有其他选择。我认为你最好的选择是采用 multi_dot 实现,例如from here for numpy v1.19.0 并调整它以处理张量/跳过强制转换为 numpy。鉴于相似的界面和简单的代码,我认为这应该非常简单。