如何通过 PyTorch 中的 Mask R-CNN 预测为图像生成准确的掩码?

How to generate accurate masks for an image from Mask R-CNN prediction in PyTorch?

我训练了一个 Mask R-CNN 网络来分割苹果。我能够加载权重并为测试图像生成预测。生成的掩码位置似乎是正确的,但掩码本身没有清晰的形状——看起来只是一堆零散的像素。

训练基于这篇论文(paper),用于训练和生成权重的代码在这个 GitHub 链接中。


预测代码如下(省略了创建路径变量并为其赋值的部分):

import os
import glob
import numpy as np
import pandas as pd
import cv2 as cv
import fileinput

import torch
import torch.utils.data
import torchvision

from data.apple_dataset import AppleDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

import utility.utils as utils
import utility.transforms as T

from PIL import Image
from matplotlib import pyplot as plt
%matplotlib inline

def get_transform(train):
    """Build the transform pipeline applied to AppleDataset samples.

    Args:
        train: when True, also add training-time augmentation.

    Returns:
        A ``T.Compose`` wrapping the selected transforms.
    """
    # ToTensor is needed at inference as well: the prediction code calls
    # ``img.to(device)`` and ``img.mul(255).permute(...)`` on the sample,
    # which only works on a tensor.
    transforms = [T.ToTensor()]
    if train:
        # NOTE(review): augmentation assumed to match the training script
        # used to produce the checkpoint -- confirm before retraining.
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

def get_maskrcnn_model_instance(num_classes):
    """Create a ResNet-50-FPN Mask R-CNN whose heads predict ``num_classes``.

    The model is built with ``pretrained=False``, so no COCO detection
    weights are loaded here; trained weights are meant to be loaded
    afterwards with ``load_state_dict``.

    Args:
        num_classes: number of output classes (torchvision convention:
            this count includes the background class).

    Returns:
        The Mask R-CNN model with freshly initialised box and mask predictors.
    """
    mask_rcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
    heads = mask_rcnn.roi_heads

    # Box head: keep the classifier's input width, swap in a new class count.
    box_in = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in, num_classes)

    # Mask head: same idea, with a 256-channel hidden layer.
    mask_in = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in, 256, num_classes)

    return mask_rcnn

num_classes = 2  # background + apple
device = torch.device('cpu')

model = get_maskrcnn_model_instance(num_classes)
checkpoint = torch.load('model_49.pth', map_location=device)
# strict=False silently skips keys that don't match, which would leave those
# layers randomly initialised (the model is built with pretrained=False).
# Stay lenient, but surface any mismatch instead of hiding it.
incompatible = model.load_state_dict(checkpoint['model'], strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print('WARNING: checkpoint does not match model:', incompatible)
model.to(device)
# torchvision detection models only return detections in eval mode; in
# training mode they expect targets and return losses instead.
model.eval()

dataset_test = AppleDataset(test_image_files_path, get_transform(train=False))
img, _ = dataset_test[1]

with torch.no_grad():
    prediction = model([img.to(device)])


Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())

(Unable to upload the image here since it is over 2 MB.)

# Masks are soft per-pixel probabilities in [0, 1]; scaling them directly
# shows faint noise. Binarise at 0.5 first to see the actual instance shape.
Image.fromarray((prediction[0]['masks'][0, 0] > 0.5).byte().mul(255).cpu().numpy())

这是原始图像的 Imgur 链接。下面是其中一个实例的预测掩码:


[{'boxes': tensor([[ 966.8143, 1633.7491, 1106.7389, 1787.6367],
          [1418.7872, 1467.0619, 1732.0828, 1796.1527],
          [1608.0396, 2064.6482, 1710.7534, 2206.5535],
          [2326.3750, 1690.3418, 2542.2112, 1883.2626],
          [2213.2024, 1864.3657, 2299.8933, 1963.0178],
          [1112.9083, 1732.5953, 1236.7600, 1823.0170],
          [1150.8256,  614.0334, 1218.8584,  711.4094],
          [ 942.7086,  794.6043, 1138.2318, 1008.0430],
          [1065.4371,  723.0493, 1192.7570,  870.3763],
          [1002.3103,  883.4616, 1146.9994, 1006.6841],
          [1315.2816, 1680.8625, 1531.3210, 1989.3317],
          [1244.5769, 1925.0903, 1459.5417, 2175.3252],
          [1725.2191, 2082.6187, 1934.0227, 2274.2952],
          [ 936.3065, 1554.3765, 1014.2722, 1659.4229],
          [ 934.8851, 1541.3331, 1090.4736, 1657.3751],
          [2486.0120,  776.4577, 2547.2329,  847.9725],
          [2336.1675,  698.6327, 2508.6492,  921.4550],
          [2368.4077, 1954.1102, 2448.4004, 2049.5796],
          [1899.1403, 1775.2371, 2035.7561, 1962.6923],
          [2176.0664, 1075.1553, 2398.6084, 1267.2555],
          [2274.8899,  641.6769, 2395.9634,  791.3353],
          [2535.1580,  874.4780, 2642.8213,  966.4614],
          [2183.4236,  619.9688, 2288.5676,  758.6825],
          [2183.9832, 1122.9382, 2334.9583, 1263.3226],
          [1135.7822,  779.0529, 1225.9871,  890.0135],
          [ 317.3954, 1328.6995,  397.3900, 1467.7740],
          [ 945.4811, 1833.3708,  997.2318, 1878.8607],
          [1992.4447,  679.4969, 2134.6667,  835.8701],
          [1098.5416, 1452.7799, 1429.1808, 1771.4460],
          [1657.3193, 1405.5405, 1781.6273, 1574.6780],
          [1443.8911, 1747.1544, 1739.0361, 2076.9724],
          [1092.6003, 1165.3340, 1206.0881, 1383.8314],
          [2466.4170, 1945.5931, 2555.1931, 2039.8368],
          [2561.8508, 1616.2659, 2672.1033, 1742.2332],
          [1894.4806,  907.9214, 2097.1875, 1182.6473],
          [2321.5005, 1701.3344, 2368.3699, 1865.3914],
          [2180.0781,  567.5969, 2344.6357,  763.4360],
          [1845.7612,  668.6808, 2045.2688,  899.8501],
          [1858.9216, 2145.7097, 1961.8870, 2273.5088],
          [ 261.4607, 1314.0154,  396.9288, 1486.9498],
          [2488.1682, 1585.2357, 2669.0178, 1794.9926],
          [2696.9548,  936.0087, 2802.7961, 1025.2294],
          [1593.6837, 1489.8641, 1720.3124, 1627.8135],
          [2517.9468,  857.1713, 2567.1125,  929.4335],
          [1943.2167,  636.3422, 2151.4419,  853.8924],
          [2143.5664, 1100.0521, 2308.1570, 1290.7125],
          [2140.9231, 1947.9692, 2238.6956, 2000.6249],
          [1461.6316, 2105.2593, 1559.7675, 2189.0264],
          [2114.0781,  374.8153, 2222.8838,  559.9851],
          [2350.5320,  726.5779, 2466.8140,  878.2617]]),
  'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1]),
  'scores': tensor([0.9916, 0.9841, 0.9669, 0.9337, 0.9118, 0.7729, 0.7202, 0.7193, 0.6928,
          0.6872, 0.6690, 0.5913, 0.4877, 0.4683, 0.3781, 0.3327, 0.3164, 0.2364,
          0.1696, 0.1692, 0.1502, 0.1365, 0.1316, 0.1171, 0.1119, 0.1094, 0.1041,
          0.0865, 0.0853, 0.0835, 0.0822, 0.0816, 0.0797, 0.0796, 0.0788, 0.0780,
          0.0757, 0.0736, 0.0736, 0.0689, 0.0681, 0.0644, 0.0642, 0.0630, 0.0612,
          0.0598, 0.0563, 0.0531, 0.0525, 0.0522]),
  'masks': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],

          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],

          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],


          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],

          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]],

          [[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]]]])}]

来自 Mask R-CNN 的预测具有以下结构:

During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows:

boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H  
labels (Int64Tensor[N]): the predicted labels for each image  
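scores (Tensor[N]): the scores of each prediction  
masks (FloatTensor[N, 1, H, W]): the predicted soft masks for each instance, in 0-1 range. To obtain the final binary segmentation masks, threshold the soft masks, generally at 0.5 (mask >= 0.5).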


import cv2  # the file's top-level import binds OpenCV only as ``cv``

# Detections below this confidence are skipped; the raw output above
# contains many detections with scores near 0.05 that only add clutter.
SCORE_THRESHOLD = 0.5
# Per-pixel probability above which a soft-mask pixel counts as foreground
# (0.5 is the value suggested in the torchvision documentation).
MASK_THRESHOLD = 0.5

# cv2.COLOR_BGR2RGB is a cvtColor conversion code, not an imread flag, so it
# must not be passed here. imread returns BGR, which is what imshow expects.
img_cv = cv2.imread('input.jpg')
if img_cv is None:
    raise FileNotFoundError("could not read 'input.jpg'")

pred = prediction[0]
for soft_mask, score in zip(pred['masks'], pred['scores']):
    if float(score) < SCORE_THRESHOLD:
        continue
    # soft_mask is (1, H, W) with values in [0, 1]. findContours treats every
    # non-zero pixel as foreground, so binarise first; otherwise every faint
    # probability becomes part of the "object".
    binary = (soft_mask[0] > MASK_THRESHOLD).byte().mul(255).cpu().numpy()
    if binary.shape != img_cv.shape[:2]:
        raise ValueError(
            f"mask shape {binary.shape} does not match image shape "
            f"{img_cv.shape[:2]}; draw on the same image the model saw"
        )
    # [-2] picks the contour list on both OpenCV 3 (3 return values) and 4 (2).
    contours = cv2.findContours(binary, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2]
    # (255, 0, 0) is blue in OpenCV's BGR channel order.
    cv2.drawContours(img_cv, contours, -1, (255, 0, 0), 2, cv2.LINE_AA)

cv2.imshow('img output', img_cv)
# imshow only queues the drawing; waitKey runs the GUI event loop so the
# window actually renders and stays open until a key is pressed.
cv2.waitKey(0)
cv2.destroyAllWindows()
