3D CNN 在图像序列上的输入形状应该是什么?
What should be the input shape for 3D CNN on a sequence of images?
https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#conv3d说明在3D CNN上做卷积的输入是(N,Cin,D,H,W)。想象一下,如果我有一系列图像要传递给 3D CNN。我说得对吗:
- N -> 序列数(小批量)
- Cin -> 通道数(rgb 为 3)
- D -> 序列中的图像数量
- H -> 序列中一张图像的高度
- W -> 序列中一张图像的宽度
我问的原因是当我堆叠图像张量时:a = torch.stack([img1, img2, img3, img4, img5])
我得到 torch.Size([5, 3, 396, 247])
的形状,所以是否必须将我的张量重塑为 torch.Size([3, 5, 396, 247])
所以该数量的通道将首先进入,或者在 Dataloader 内部无关紧要?
请注意,Dataloader 会自动添加一个与 N 相对应的维度。
是的,这很重要,您需要确保尺寸排序正确(假设您使用 DataLoader
的默认整理功能)。一种方法是使用 dim=1
而不是默认的 dim=0
来调用 torch.stack
。例如
a = torch.stack([img1, img2, img3, img4, img5], dim=1)
导致 a
成为 [3, 5, 396, 247]
的所需形状。
https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#conv3d说明在3D CNN上做卷积的输入是(N,Cin,D,H,W)。想象一下,如果我有一系列图像要传递给 3D CNN。我说得对吗:
- N -> 序列数(小批量)
- Cin -> 通道数(rgb 为 3)
- D -> 序列中的图像数量
- H -> 序列中一张图像的高度
- W -> 序列中一张图像的宽度
我问的原因是当我堆叠图像张量时:a = torch.stack([img1, img2, img3, img4, img5])
我得到 torch.Size([5, 3, 396, 247])
的形状,所以是否必须将我的张量重塑为 torch.Size([3, 5, 396, 247])
所以该数量的通道将首先进入,或者在 Dataloader 内部无关紧要?
请注意,Dataloader 会自动添加一个与 N 相对应的维度。
是的,这很重要,您需要确保尺寸排序正确(假设您使用 DataLoader
的默认整理功能)。一种方法是使用 dim=1
而不是默认的 dim=0
来调用 torch.stack
。例如
a = torch.stack([img1, img2, img3, img4, img5], dim=1)
导致 a
成为 [3, 5, 396, 247]
的所需形状。