如何处理 Unet 架构 PyTorch 中的奇怪分辨率
How to handle odd resolutions in Unet architecture PyTorch
我正在 PyTorch 中实现基于 U-Net 的架构。在火车时间,我有大小 256x256
的补丁,不会造成任何问题。但是在测试时,我有全高清图像 (1920x1080
)。这会导致跳过连接期间出现问题。
下采样 1920x1080
3 次得到 240x135
。如果我再向下采样一次,分辨率变为 120x68
,当向上采样时,分辨率变为 240x136
。现在,我无法连接这两个特征图。我该如何解决?
PS:我认为这是一个相当普遍的问题,但我没有得到任何解决方案,甚至在网络上的任何地方都没有提到这个问题。我错过了什么吗?
在解码过程中经常涉及跳跃连接的分割网络中,这是一个非常普遍的问题。网络通常(取决于实际架构)需要边长为最大步幅(8、16、32 等)整数倍的输入大小。
主要有两种方式:
- 将输入调整为最接近的可行大小。
- 将输入填充到下一个更大的可行大小。
我更喜欢 (2),因为 (1) 会导致所有像素的像素级别发生微小变化,从而导致不必要的模糊。请注意,我们通常需要在这两种方法中恢复原始形状。
此任务我最喜欢的代码片段(height/width 的对称填充):
import torch
import torch.nn.functional as F
def pad_to(x, stride):
h, w = x.shape[-2:]
if h % stride > 0:
new_h = h + stride - h % stride
else:
new_h = h
if w % stride > 0:
new_w = w + stride - w % stride
else:
new_w = w
lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
pads = (lw, uw, lh, uh)
# zero-padding by default.
# See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
out = F.pad(x, pads, "constant", 0)
return out, pads
def unpad(x, pad):
if pad[2]+pad[3] > 0:
x = x[:,:,pad[2]:-pad[3],:]
if pad[0]+pad[1] > 0:
x = x[:,:,:,pad[0]:-pad[1]]
return x
一个测试片段:
x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape
print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)
输出:
Original: torch.Size([4, 3, 1080, 1920])
Padded: torch.Size([4, 3, 1088, 1920])
Recovered: torch.Size([4, 3, 1080, 1920])
参考:https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33
我正在 PyTorch 中实现基于 U-Net 的架构。在火车时间,我有大小 256x256
的补丁,不会造成任何问题。但是在测试时,我有全高清图像 (1920x1080
)。这会导致跳过连接期间出现问题。
下采样 1920x1080
3 次得到 240x135
。如果我再向下采样一次,分辨率变为 120x68
,当向上采样时,分辨率变为 240x136
。现在,我无法连接这两个特征图。我该如何解决?
PS:我认为这是一个相当普遍的问题,但我没有得到任何解决方案,甚至在网络上的任何地方都没有提到这个问题。我错过了什么吗?
在解码过程中经常涉及跳跃连接的分割网络中,这是一个非常普遍的问题。网络通常(取决于实际架构)需要边长为最大步幅(8、16、32 等)整数倍的输入大小。
主要有两种方式:
- 将输入调整为最接近的可行大小。
- 将输入填充到下一个更大的可行大小。
我更喜欢 (2),因为 (1) 会导致所有像素的像素级别发生微小变化,从而导致不必要的模糊。请注意,我们通常需要在这两种方法中恢复原始形状。
此任务我最喜欢的代码片段(height/width 的对称填充):
import torch
import torch.nn.functional as F
def pad_to(x, stride):
h, w = x.shape[-2:]
if h % stride > 0:
new_h = h + stride - h % stride
else:
new_h = h
if w % stride > 0:
new_w = w + stride - w % stride
else:
new_w = w
lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2)
lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2)
pads = (lw, uw, lh, uh)
# zero-padding by default.
# See others at https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.pad
out = F.pad(x, pads, "constant", 0)
return out, pads
def unpad(x, pad):
if pad[2]+pad[3] > 0:
x = x[:,:,pad[2]:-pad[3],:]
if pad[0]+pad[1] > 0:
x = x[:,:,:,pad[0]:-pad[1]]
return x
一个测试片段:
x = torch.zeros(4, 3, 1080, 1920) # Raw data
x_pad, pads = pad_to(x, 16) # Padded data, feed this to your network
x_unpad = unpad(x_pad, pads) # Un-pad the network output to recover the original shape
print('Original: ', x.shape)
print('Padded: ', x_pad.shape)
print('Recovered: ', x_unpad.shape)
输出:
Original: torch.Size([4, 3, 1080, 1920])
Padded: torch.Size([4, 3, 1088, 1920])
Recovered: torch.Size([4, 3, 1080, 1920])
参考:https://github.com/seoungwugoh/STM/blob/905f11492a6692dd0d0fa395881a8ec09b211a36/helpers.py#L33