Pytorch 中的批量矩阵乘法 - 与输出维度的处理混淆

Batch-Matrix multiplication in Pytorch - Confused with the handling of the output's dimension

我有两个数组:

A
B

数组A包含一批RGB图像,形状为:

[batch, Width, Height, 3]

而数组 B 包含对图像进行 "transformation-like" 操作所需的系数,形状为:

[batch, 4, 4, 3]

简单来说,对单张图片的操作就是乘法输出一张环境贴图(normalMap * Coefficients)。

我想要的输出应该保持形状:

[batch, Width, Height, 3]

我尝试使用 torch.bmm 但失败了。这有可能吗?

我认为你需要计算 PyTorch 与

BxCxHxW : number of mini-batches, channels, height, width

格式,并且还使用 matmul, since bmm 与张​​量一起使用或 ndim/dim/rank =3.

我知道你可以在网上找到这个,但无论如何:

batch1 = torch.randn(10, 3, 20, 10)
batch2 = torch.randn(10, 3, 10, 30)
res = torch.matmul(batch1, batch2)
res.size() # torch.Size([10, 3, 20, 30])