我怎样才能更有效地将一批张量中的每个元素与除自身之外的所有其他批次元素相乘?
How can I more efficiently multiply every element in a batch of tensors with every other batch element, except for itself?
所以,我有这段代码可以将一批张量中的每个元素与除自身之外的所有其他元素相乘。该代码有效,但随着批量大小的增加,它变得非常慢(理想情况下,我希望能够将它用于高达 1000 或更多的批量大小,但即使是几百也可以)。当使用 PyTorch autograd 系统和大批量(如 50 或更大)时,它基本上会冻结。
我需要帮助使代码更快、更高效,同时仍获得相同的输出。如有任何帮助,我们将不胜感激!
import torch
tensor = torch.randn(50, 512, 512)
batch_size = tensor.size(0)
list1 = []
for i in range(batch_size):
list2 = []
for j in range(batch_size):
if j != i:
x_out = (tensor[i] * tensor[j]).sum()
list2.append(x_out )
list1.append(sum(list2))
out = sum(list1)
我认为 torch.prod
可能可以使用,但它似乎不会产生与上面代码相同的输出。 NumPy 答案是可以接受的,只要它们可以在 PyTorch 中重新创建。
您可以执行以下操作:
import torch
tensor = torch.randn(50, 512, 512)
batch_size = tensor.size(0)
tensor = tensor.reshape(batch_size, -1)
prod = torch.matmul(tensor, tensor.transpose(0,1))
out = torch.sum(prod) - torch.trace(prod)
在这里,您首先展平每个元素。然后,将每行都是一个元素的矩阵与其自己的转置相乘,得到一个 batch_size x batch_size
矩阵,其中第 ij
个元素等于 tensor[i]
与 [=14= 的乘积].因此,对该矩阵中的值求和并减去它的迹线(即对角线元素的总和)得到所需的结果。
两种方法我都试过了,batch_size
为1000,耗时从61.43秒降到0.59秒。
所以,我有这段代码可以将一批张量中的每个元素与除自身之外的所有其他元素相乘。该代码有效,但随着批量大小的增加,它变得非常慢(理想情况下,我希望能够将它用于高达 1000 或更多的批量大小,但即使是几百也可以)。当使用 PyTorch autograd 系统和大批量(如 50 或更大)时,它基本上会冻结。
我需要帮助使代码更快、更高效,同时仍获得相同的输出。如有任何帮助,我们将不胜感激!
import torch
tensor = torch.randn(50, 512, 512)
batch_size = tensor.size(0)
list1 = []
for i in range(batch_size):
list2 = []
for j in range(batch_size):
if j != i:
x_out = (tensor[i] * tensor[j]).sum()
list2.append(x_out )
list1.append(sum(list2))
out = sum(list1)
我认为 torch.prod
可能可以使用,但它似乎不会产生与上面代码相同的输出。 NumPy 答案是可以接受的,只要它们可以在 PyTorch 中重新创建。
您可以执行以下操作:
import torch
tensor = torch.randn(50, 512, 512)
batch_size = tensor.size(0)
tensor = tensor.reshape(batch_size, -1)
prod = torch.matmul(tensor, tensor.transpose(0,1))
out = torch.sum(prod) - torch.trace(prod)
在这里,您首先展平每个元素。然后,将每行都是一个元素的矩阵与其自己的转置相乘,得到一个 batch_size x batch_size
矩阵,其中第 ij
个元素等于 tensor[i]
与 [=14= 的乘积].因此,对该矩阵中的值求和并减去它的迹线(即对角线元素的总和)得到所需的结果。
两种方法我都试过了,batch_size
为1000,耗时从61.43秒降到0.59秒。