将蒙版恢复为图像

Restoring mask to image

我的 PyTorch 模型为三个 classes 中的每一个输出一个具有值 (0,1,2) 的分割图像。在准备set的过程中,我把黑色映射到0,红色映射到1,白色映射到2。我有两个问题:

  1. 如何显示每个 class 代表什么?例如看一下图像: 我目前正在使用以下方法来显示每个 class:

         output = net(input)
    
         input = input.cpu().squeeze()
         input = transforms.ToPILImage()(input)
    
         probs = F.softmax(output, dim=1)
         probs = probs.squeeze(0)
    
         full_mask = probs.squeeze().cpu().numpy()
    
         fig, (ax0, ax1, ax2, ax3, ax4) = plt.subplots(1, 5, figsize=(20,10), sharey=True)
    
         ax0.set_title('Input Image')
         ax1.set_title('Background Class')
         ax2.set_title('Neuron Class')
         ax3.set_title('Dendrite Class')
         ax4.set_title('Predicted Mask')
    
         ax0.imshow(input)
         ax1.imshow(full_mask[0, :, :].squeeze())
         ax2.imshow(full_mask[1, :, :].squeeze())
         ax3.imshow(full_mask[2, :, :].squeeze())
    
         full_mask = np.argmax(full_mask, 0)
         img = mask_to_image(full_mask)
    

但是 classes 之间似乎有共享像素,有没有更好的方法来显示这个(我希望第一张图片仅包含背景 class,第二张仅包含神经元的 class 和树突的第三个 class)?

2.My 第二个问题是关于从蒙版生成黑白图像,目前蒙版的形状为 (512,512) 并具有以下值:

[[0 0 0 ... 0 0 0]
 [0 0 0 ... 2 0 0]
 [0 0 0 ... 2 2 0]
 ...
 [2 1 2 ... 2 2 2]
 [2 1 2 ... 2 2 2]
 [0 2 0 ... 2 2 2]]

结果如下所示:

由于我使用此代码转换为图像:

def mask_to_image(mask):
   return Image.fromarray((mask).astype(np.uint8))

But there appears to be shared pixels between the classes, is there a better way to show this (I want the first image to only of the background class, the the second only of the neuron class and the third only of the dendrite class)?

是的,您可以沿 0th 维度取 argmax,因此具有最高 logit(非标准化概率)的那个将是 1,其余将为零:

output = net(input)

binary_mask = torch.argmax(output, dim=0).cpu().numpy()
ax.set_title('Neuron Class')
ax.imshow(binary_mask == 0)

My second question is about generating a black, red and white image from the mask, currently the mask is of shape (512,512) and has the following values

您可以将 [0, 1, 2] 值散布到 zero-th 轴中,使其成为 channel-wise。现在单个像素所有通道的 [0, 0, 0] 值为 black[255, 255, 255] 为白色,[255, 0, 0] 为红色(因为 PIL 为 RGB 格式):

import torch

tensor = torch.randint(high=3, size=(512, 512))

red = tensor == 0
white = tensor == 2

zero_channel = red & white

image = torch.stack([zero_channel, white, white]).int().numpy() * 255
Image.fromarray((image).astype(np.uint8))