PyTorch 中复数的矩阵乘法
matrix multiplication for complex numbers in PyTorch
我正在尝试在 PyTorch 中将两个复数矩阵相乘,看起来 the torch.matmul functions is not added yet to PyTorch library for complex numbers.
您有什么推荐或者有其他方法可以在 PyTorch 中乘复数矩阵吗?
我使用 torch.mv 为 pytorch.matmul 实现了复数的这个函数,它对 time-being 工作正常:
def matmul_complex(t1, t2):
m = list(t1.size())[0]
n = list(t2.size())[1]
t = torch.empty((1,n), dtype=torch.cfloat)
t_total = torch.empty((m,n), dtype=torch.cfloat)
for i in range(0,n):
if i == 0:
t_total = torch.mv(t1,t2[:,i])
else:
t_total = torch.cat((t_total, torch.mv(t1,t2[:,i])), 0)
t_final = torch.reshape(t_total, (m,n))
return t_final
我是PyTorch新手,如有错误请指正。
目前 torch.matmul
不支持诸如 ComplexFloatTensor
之类的复杂张量,但您可以像下面的代码那样做一些紧凑的事情:
def matmul_complex(t1,t2):
return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag, t1.real @ t2.imag + t1.imag @ t2.real),dim=2))
尽可能避免使用 for 循环,因为这会导致执行速度变慢。
矢量化是通过使用 built-in 方法实现的,如我所附的代码所示。
例如,对于 2 个维度为 1000 X 1000 的随机复杂矩阵,您的代码在 CPU 上大约需要 6.1 秒,而矢量化版本仅需要 101 毫秒(快约 60 倍)。
更新:
从 PyTorch 1.7.0 开始(如@EduardoReis 所述),您可以在复数矩阵之间进行矩阵乘法,类似于 real-valued 矩阵,如下所示:
t1 @ t2
(对于 t1
、t2
复数矩阵)。
我正在尝试在 PyTorch 中将两个复数矩阵相乘,看起来 the torch.matmul functions is not added yet to PyTorch library for complex numbers.
您有什么推荐或者有其他方法可以在 PyTorch 中乘复数矩阵吗?
我使用 torch.mv 为 pytorch.matmul 实现了复数的这个函数,它对 time-being 工作正常:
def matmul_complex(t1, t2):
m = list(t1.size())[0]
n = list(t2.size())[1]
t = torch.empty((1,n), dtype=torch.cfloat)
t_total = torch.empty((m,n), dtype=torch.cfloat)
for i in range(0,n):
if i == 0:
t_total = torch.mv(t1,t2[:,i])
else:
t_total = torch.cat((t_total, torch.mv(t1,t2[:,i])), 0)
t_final = torch.reshape(t_total, (m,n))
return t_final
我是PyTorch新手,如有错误请指正。
目前 torch.matmul
不支持诸如 ComplexFloatTensor
之类的复杂张量,但您可以像下面的代码那样做一些紧凑的事情:
def matmul_complex(t1,t2):
return torch.view_as_complex(torch.stack((t1.real @ t2.real - t1.imag @ t2.imag, t1.real @ t2.imag + t1.imag @ t2.real),dim=2))
尽可能避免使用 for 循环,因为这会导致执行速度变慢。 矢量化是通过使用 built-in 方法实现的,如我所附的代码所示。 例如,对于 2 个维度为 1000 X 1000 的随机复杂矩阵,您的代码在 CPU 上大约需要 6.1 秒,而矢量化版本仅需要 101 毫秒(快约 60 倍)。
更新:
从 PyTorch 1.7.0 开始(如@EduardoReis 所述),您可以在复数矩阵之间进行矩阵乘法,类似于 real-valued 矩阵,如下所示:
t1 @ t2
(对于 t1
、t2
复数矩阵)。