如何在不将所有行存储在内存中或迭代的情况下有效地乘以具有重复行的火炬张量?

How to efficiently multiply by torch tensor with repeated rows without storing all the rows in memory or iterating?

给定火炬张量:

# example tensor size 2 x 4
a = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

还有一个每 n 行重复一次的地方:

# example tensor size 4 x 3 where every 2 rows repeated
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])

如何进行矩阵乘法:

>>> torch.mm(a, b)
tensor([[ 28.,  38.,  48.],
        [ 68.,  94., 120.]])

没有将整个重复行张量复制到内存或迭代?

即仅存储前 2 行:

# example tensor size 2 x 3 where only the first two rows from b are actually stored in memory
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])

因为这些行会重复。

有函数

torch.expand()

但这在重复一行以上时确实有效,而且,如这个问题:

表明并且我自己的测试证实在调用时通常最终会将整个张量复制到内存中

.to(device)

也可以迭代执行此操作,但这相对较慢。

有没有什么方法可以在不将整个重复行张量存储在内存中的情况下有效地执行此操作?

编辑说明:

抱歉,最初没有澄清:一个被用作第一个张量的第一维以保持示例简单,但实际上我正在寻找任何两个张量 a 和 b 的一般情况的解决方案,使得它们的维度适合矩阵乘法,并且 b 的行每 n 行重复一次。我更新了示例以反映这一点。

假设 a 的第一个维度在您的示例中为 1,您可以执行以下操作:

a = torch.Tensor([[1, 2, 3, 4]])
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
torch.mm(a.reshape(-1, 2), b_abbreviated).sum(axis=0, keepdim=True)

在这里,不是重复行,而是将 a 乘以块,然后按列将它们相加以获得相同的结果。


如果a的第一个维度不一定是1,你可以试试下面的方法:

torch.cat(torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]), dim=1).sum(
dim=0, keepdim=True).reshape(a.shape[0], -1)

在这里,您执行以下操作:

  • 使用 torch.mm(a.reshape(-1,2),b_abbreviated,您再次将 a 的每一行拆分为大小为 2 的块,并将它们一层层堆叠,然后将每一行堆叠在另一层之上。
  • 使用 torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]),然后将这些堆栈按行分开,以便拆分的每个结果组件对应于单行的块。
  • 然后 torch.cat(torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]), dim=1) 这些堆栈按列连接。
  • .sum(dim=0, keepdim=True) 中,与 a 中各个行的单独块相对应的结果相加。
  • 对于 .reshape(a.shape[0], -1),按列连接的 a 行再次按行堆叠。

与直接矩阵乘法相比,它似乎相当慢,这并不奇怪,但我还没有检查与显式迭代相比。可能有更好的方法,如果我想到任何会编辑。