PyTorch 中的非维转置

No N-dimensional tranpose in PyTorch

PyTorch 的 torch.transpose 函数仅转置 2D 输入。文档是 here.

另一方面,Tensorflow 的 tf.transpose 函数允许您转置 N 任意维度的张量。

有人可以解释为什么 PyTorch not/cannot 具有 N 维转置功能吗?这是由于 PyTorch 中计算图构造的动态特性与 Tensorflow 的 Define-then-运行 范式相比吗?

只是在pytorch中叫法不一样而已。 torch.Tensor.permute 将允许您在 pytorch 中交换尺寸,就像 tf.transpose 在 TensorFlow 中那样。

作为如何将 4D 图像张量从 NHWC 转换为 NCHW 的示例(未测试,因此可能包含错误):

>>> img_nhwc = torch.randn(10, 480, 640, 3)
>>> img_nhwc.size()
torch.Size([10, 480, 640, 3])
>>> img_nchw = img_nhwc.permute(0, 3, 1, 2)
>>> img_nchw.size()
torch.Size([10, 3, 480, 640])

Einops 支持任意维数的详细转置:

from einops import rearrange
x  = torch.zeros(10, 3, 100, 100)
y  = rearrange(x, 'b c h w -> b h w c')
x2 = rearrange(y, 'b h w c -> b c h w') # inverse to the first

(同样的代码也适用于 tensorfow)