如何在 pytorch 图像处理模型中处理具有多个图像的样本?

How to handle samples with multiple images in a pytorch image processing model?

我的模型训练涉及对同一图像的多个变体进行编码,然后对图像的所有变体生成的表示求和。

数据加载器生成形状的张量批次:[batch_size,num_variants,1,height,width]1 对应图像颜色通道。

如何在 pytorch 中使用小批量训练我的模型? 我正在寻找一种适当的方法来通过网络转发所有 batch_size×num_variant 图像并对所有变体组的结果求和。

我当前的解决方案涉及展平前两个维度并执行 for 循环来对表示求和,但我觉得应该有更好的方法,但我不确定梯度是否会记住所有内容。

不确定我是否理解正确,但我想这就是您想要的(假设批处理图像张量称为 image):

Nb, Nv, inC, inH, inW = image.shape

# treat each variant as if it's an ordinary image in the batch
image = image.reshape(Nb*Nv, inC, inH, inW)

output = model(image)
_, outC, outH, outW = output.shape[1]

# reshapes the output such that dim==1 indicates variants
output = output.reshape(Nb, Nv, outC, outH, outW)

# summing over the variants and lose the dimension of summation, [Nb, outC, outH, outW]
output = output.sum(dim=1, keepdim=False)

我使用了 inCoutCinH 等,以防输入和输出 channels/sizes 不同。