How to generate accurate masks for an image from Mask R-CNN prediction in PyTorch?
I trained a Mask R-CNN network to segment apples. I am able to load the weights and generate predictions for my test images. The masks are being generated in the correct locations, but the masks themselves have no real form; they just look like a bunch of pixels.
Training was done on the dataset from this paper, and here is the github link to the code used to train and generate the weights.
The prediction code is below. (I have omitted the part that creates the path variables and assigns the paths.)
import os
import glob
import numpy as np
import pandas as pd
import cv2 as cv
import fileinput
import torch
import torch.utils.data
import torchvision
from data.apple_dataset import AppleDataset
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
import utility.utils as utils
import utility.transforms as T
from PIL import Image
from matplotlib import pyplot as plt
%matplotlib inline
def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)
def get_maskrcnn_model_instance(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False)
    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)
    return model
num_classes = 2
device = torch.device('cpu')
model = get_maskrcnn_model_instance(num_classes)
checkpoint = torch.load('model_49.pth', map_location=device)
model.load_state_dict(checkpoint['model'], strict=False)
dataset_test = AppleDataset(test_image_files_path, get_transform(train=False))
img, _ = dataset_test[1]
model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])
prediction
Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())
(Unable to load the image here since it's over 2 MB.)
Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())
Here is an Imgur link to the original image. Below is the predicted mask for one of the instances.
Also, could you please help me understand the data structure of the generated prediction shown below? How do I access the masks so as to generate a single image showing all the masks?
[{'boxes': tensor([[ 966.8143, 1633.7491, 1106.7389, 1787.6367],
[1418.7872, 1467.0619, 1732.0828, 1796.1527],
[1608.0396, 2064.6482, 1710.7534, 2206.5535],
[2326.3750, 1690.3418, 2542.2112, 1883.2626],
[2213.2024, 1864.3657, 2299.8933, 1963.0178],
[1112.9083, 1732.5953, 1236.7600, 1823.0170],
[1150.8256, 614.0334, 1218.8584, 711.4094],
[ 942.7086, 794.6043, 1138.2318, 1008.0430],
[1065.4371, 723.0493, 1192.7570, 870.3763],
[1002.3103, 883.4616, 1146.9994, 1006.6841],
[1315.2816, 1680.8625, 1531.3210, 1989.3317],
[1244.5769, 1925.0903, 1459.5417, 2175.3252],
[1725.2191, 2082.6187, 1934.0227, 2274.2952],
[ 936.3065, 1554.3765, 1014.2722, 1659.4229],
[ 934.8851, 1541.3331, 1090.4736, 1657.3751],
[2486.0120, 776.4577, 2547.2329, 847.9725],
[2336.1675, 698.6327, 2508.6492, 921.4550],
[2368.4077, 1954.1102, 2448.4004, 2049.5796],
[1899.1403, 1775.2371, 2035.7561, 1962.6923],
[2176.0664, 1075.1553, 2398.6084, 1267.2555],
[2274.8899, 641.6769, 2395.9634, 791.3353],
[2535.1580, 874.4780, 2642.8213, 966.4614],
[2183.4236, 619.9688, 2288.5676, 758.6825],
[2183.9832, 1122.9382, 2334.9583, 1263.3226],
[1135.7822, 779.0529, 1225.9871, 890.0135],
[ 317.3954, 1328.6995, 397.3900, 1467.7740],
[ 945.4811, 1833.3708, 997.2318, 1878.8607],
[1992.4447, 679.4969, 2134.6667, 835.8701],
[1098.5416, 1452.7799, 1429.1808, 1771.4460],
[1657.3193, 1405.5405, 1781.6273, 1574.6780],
[1443.8911, 1747.1544, 1739.0361, 2076.9724],
[1092.6003, 1165.3340, 1206.0881, 1383.8314],
[2466.4170, 1945.5931, 2555.1931, 2039.8368],
[2561.8508, 1616.2659, 2672.1033, 1742.2332],
[1894.4806, 907.9214, 2097.1875, 1182.6473],
[2321.5005, 1701.3344, 2368.3699, 1865.3914],
[2180.0781, 567.5969, 2344.6357, 763.4360],
[1845.7612, 668.6808, 2045.2688, 899.8501],
[1858.9216, 2145.7097, 1961.8870, 2273.5088],
[ 261.4607, 1314.0154, 396.9288, 1486.9498],
[2488.1682, 1585.2357, 2669.0178, 1794.9926],
[2696.9548, 936.0087, 2802.7961, 1025.2294],
[1593.6837, 1489.8641, 1720.3124, 1627.8135],
[2517.9468, 857.1713, 2567.1125, 929.4335],
[1943.2167, 636.3422, 2151.4419, 853.8924],
[2143.5664, 1100.0521, 2308.1570, 1290.7125],
[2140.9231, 1947.9692, 2238.6956, 2000.6249],
[1461.6316, 2105.2593, 1559.7675, 2189.0264],
[2114.0781, 374.8153, 2222.8838, 559.9851],
[2350.5320, 726.5779, 2466.8140, 878.2617]]),
'labels': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1]),
'scores': tensor([0.9916, 0.9841, 0.9669, 0.9337, 0.9118, 0.7729, 0.7202, 0.7193, 0.6928,
0.6872, 0.6690, 0.5913, 0.4877, 0.4683, 0.3781, 0.3327, 0.3164, 0.2364,
0.1696, 0.1692, 0.1502, 0.1365, 0.1316, 0.1171, 0.1119, 0.1094, 0.1041,
0.0865, 0.0853, 0.0835, 0.0822, 0.0816, 0.0797, 0.0796, 0.0788, 0.0780,
0.0757, 0.0736, 0.0736, 0.0689, 0.0681, 0.0644, 0.0642, 0.0630, 0.0612,
0.0598, 0.0563, 0.0531, 0.0525, 0.0522]),
'masks': tensor([[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
...,
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]],
[[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]]])}]
The prediction from Mask R-CNN has the following structure:
During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows:
boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between 0 and H and 0 and W
labels (Int64Tensor[N]): the predicted labels for each image
scores (Tensor[N]): the scores of each prediction
masks (UInt8Tensor[N, 1, H, W]): the predicted masks for each instance, in 0-1 range.
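For the second part of the question, each field can be indexed per instance, and the instance masks can be merged into one image. (Note that despite the UInt8Tensor in the docs quote, the dump above shows float probability masks, which is why a threshold is applied.) A minimal sketch, assuming the prediction variable from your code; the 0.5 score and mask cut-offs are arbitrary choices, not values from the paper:
out = prediction[0]                  # one Dict per input image
keep = out['scores'] > 0.5           # drop low-confidence detections
masks = out['masks'][keep, 0]        # (N, H, W) soft masks with values in [0, 1]
# binarize each mask, then OR them together into a single image
combined = (masks > 0.5).any(dim=0)  # (H, W) bool
Image.fromarray(combined.byte().mul(255).cpu().numpy())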
Alternatively, you can use OpenCV's findContours and drawContours functions to draw the masks as follows:
img_cv = cv2.imread('input.jpg')  # imread's second argument is an IMREAD_* flag, not a color-conversion code
for i in range(len(prediction[0]['masks'])):
    # iterate over masks
    mask = prediction[0]['masks'][i, 0]
    mask = mask.mul(255).byte().cpu().numpy()
    contours, _ = cv2.findContours(
        mask.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    cv2.drawContours(img_cv, contours, -1, (255, 0, 0), 2, cv2.LINE_AA)
cv2.imshow('img output', img_cv)
cv2.waitKey(0)
Example output:
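As for why the original mask looks like "a bunch of pixels": at inference time each predicted mask is a soft per-pixel probability map (hence the float values in the dump above), so mask.mul(255) renders a fuzzy grayscale blob. Thresholding the mask before displaying it, or before passing it to findContours, usually recovers a clean shape. A minimal sketch, with 0.5 as an arbitrary cut-off:
mask = prediction[0]['masks'][0, 0].cpu().numpy()   # (H, W) float probabilities
binary = (mask > 0.5).astype(np.uint8) * 255        # hard 0/255 mask
Image.fromarray(binary)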