Pytorch broadcasting command not found

I have the following nested for-loop snippet in my code. The nested loop is slowing down my overall execution.

For a torch tensor extended_output of shape [batchSize, nClass*repeat] and another torch tensor output of shape [batchSize, nClass], I want the aggregation to happen as follows:

for q in range(nClass):
    for u in range(repeat):
        output[:,q]=output[:,q]+extended_output[:,(q+u*nClass)]

Here, nClass and repeat are both integer variables, with values 1400 and 8 respectively.

Is it possible to use PyTorch broadcasting to avoid this nested for loop? Any help would be much appreciated.

A sample working code could be like this:

import torch
nClass=1400
repeat=8
batchSize=64
output=torch.zeros([batchSize,nClass])
extended_output=torch.rand([batchSize,nClass*repeat])

for q in range(nClass):
    for u in range(repeat):
        output[:,q]=output[:,q]+extended_output[:,(q+u*nClass)]

Sorry for the short and perhaps oversimplified example; I'm afraid a bigger one would be harder to visualize. But I hope it serves your purpose. Here's what I would do:

import torch
nClass    = 3
repeat    = 2
batchSize = 4

torch.manual_seed(0)

output          = torch.zeros([batchSize,nClass])
extended_output = torch.rand([batchSize,nClass*repeat])


for q in range(nClass):
    for u in range(repeat):
        output[:,q]=output[:,q]+extended_output[:,(q+u*nClass)]

# Build an index matrix of shape [nClass, repeat]:
# idxs[q, u] = q + u*nClass, i.e. the columns of extended_output
# that contribute to output[:, q].
idxs = (torch.arange(repeat)*nClass).unsqueeze(0)
idxs = idxs + torch.arange(nClass).unsqueeze(1)
# Fancy-index those columns and sum over the repeat dimension.
output_vectorized = extended_output[:, idxs].sum(2)

Output:

extended_output = 
tensor([[0.4963, 0.7682, 0.0885, 0.1320, 0.3074, 0.6341],
        [0.4901, 0.8964, 0.4556, 0.6323, 0.3489, 0.4017],
        [0.0223, 0.1689, 0.2939, 0.5185, 0.6977, 0.8000],
        [0.1610, 0.2823, 0.6816, 0.9152, 0.3971, 0.8742]])
output = 
tensor([[0.6283, 1.0756, 0.7226],
        [1.1224, 1.2453, 0.8573],
        [0.5408, 0.8665, 1.0939],
        [1.0762, 0.6794, 1.5558]])
output_vectorized = 
tensor([[0.6283, 1.0756, 0.7226],
        [1.1224, 1.2453, 0.8573],
        [0.5408, 0.8665, 1.0939],
        [1.0762, 0.6794, 1.5558]])