如何在pytorch中有效地计算两组不同大小的3D张量的距离矩阵?

How to efficiently calculate distance matrix in pytorch for two sets 3D tensors with different sizes?

我有形状为 BxCxHxW 的张量 X 和形状为 NxCxHxW 的张量 Y。 B 是批量大小,C 是通道,H 是高度,W 是宽度,N 对任何批次都是常数。基本上我想要一组 B 图像和另一组 N 图像之间距离的 BxN 距离矩阵。

我尝试使用 torch.cdist 将 X 重塑为 1xBx(C*H*W) 并将 Y 重塑为 1xNx(C*H*W) 通过取消压缩维度并展平最后 3 个通道,但我做到了使用此方法进行健全性检查并得到错误答案。

我要L2距离

根据 torch.cdist 的文档页面,两个输入和输出的形状如下:x1(B, P, M)x2(B, R, M),以及 output(B, P, R)

为了匹配您的情况:B=1P=BR=N,而 M=C*H*W 变平)。正如你刚才所解释的。

所以你基本上是为了:

>>> torch.cdist(X[None].flatten(2), Y[None].flatten(2))

如果你不服气,你可以用下面的方法查看:

>>> dist = []
>>> for x in X:
...    for y in Y:
...       dist.append((x-y).norm())

并将 torch.cdist 结果与 torch.tensor(dist).reshape(len(X), len(Y)) 进行比较。