PyTorch:PyTorch 中的 numpy.linalg.multi_dot() 等价物是什么
PyTorch: What is numpy.linalg.multi_dot() equivalent in PyTorch
我正在尝试在 PyTorch 中执行多个矩阵的矩阵乘法,想知道 PyTorch 中 numpy.linalg.multi_dot()
的等价物是什么?
如果没有,我可以在 PyTorch 中执行此操作的下一个最佳方法是什么(在速度和内存方面)?
代码:
import numpy as np
import torch
A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)
results = np.linalg.multi_dot(A, B, C)
A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)
# What is the PyTorch equivalent of np.linalg.multi_dot()?
非常感谢!
~~看起来可以将张量发送到 multi_dot~~
看起来 numpy 实现将所有内容都转换为 numpy 数组。如果您的张量在 cpu 上并且分离,这应该可以工作。否则,转换为 numpy 将失败。
所以总的来说 - 可能没有其他选择。我认为你最好的选择是采用 multi_dot
实现,例如from here for numpy v1.19.0 并调整它以处理张量/跳过强制转换为 numpy。鉴于相似的界面和简单的代码,我认为这应该非常简单。
我正在尝试在 PyTorch 中执行多个矩阵的矩阵乘法,想知道 PyTorch 中 numpy.linalg.multi_dot()
的等价物是什么?
如果没有,我可以在 PyTorch 中执行此操作的下一个最佳方法是什么(在速度和内存方面)?
代码:
import numpy as np
import torch
A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
C = np.random.rand(3, 3)
results = np.linalg.multi_dot(A, B, C)
A_tsr = torch.tensor(A)
B_tsr = torch.tensor(B)
C_tsr = torch.tensor(C)
# What is the PyTorch equivalent of np.linalg.multi_dot()?
非常感谢!
~~看起来可以将张量发送到 multi_dot~~
看起来 numpy 实现将所有内容都转换为 numpy 数组。如果您的张量在 cpu 上并且分离,这应该可以工作。否则,转换为 numpy 将失败。
所以总的来说 - 可能没有其他选择。我认为你最好的选择是采用 multi_dot
实现,例如from here for numpy v1.19.0 并调整它以处理张量/跳过强制转换为 numpy。鉴于相似的界面和简单的代码,我认为这应该非常简单。