Error Training Custom COCO Dataset with Detectron2

I'm trying to train a custom COCO-format dataset with Detectron2 on PyTorch. My dataset is a JSON file in the COCO format described above, and each item in the "annotations" section looks like this:

The code for setting up Detectron2 and registering the training & validation datasets is as follows:

from detectron2.data.datasets import register_coco_instances
for d in ["train", "validation"]:
    register_coco_instances(f"segmentation_{d}", {}, f"/content/drive/MyDrive/Segmentation Annotations/{d}.json", f"/content/drive/MyDrive/Segmentation Annotations/imgs")

import os

from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("segmentation_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025  
cfg.SOLVER.MAX_ITER = 1000    
cfg.SOLVER.STEPS = []        
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 20  

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

However, when I run training, I get the following error after the first iteration:

KeyError                                  Traceback (most recent call last)
<ipython-input-12-2aaec108c313> in <module>()
     17 trainer = DefaultTrainer(cfg)
     18 trainer.resume_or_load(resume=False)
---> 19 trainer.train()

8 frames
/usr/local/lib/python3.7/dist-packages/detectron2/engine/defaults.py in train(self)
    482             OrderedDict of results, if evaluation is enabled. Otherwise None.
    483         """
--> 484         super().train(self.start_iter, self.max_iter)
    485         if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
    486             assert hasattr(

/usr/local/lib/python3.7/dist-packages/detectron2/engine/train_loop.py in train(self, start_iter, max_iter)
    147                 for self.iter in range(start_iter, max_iter):
    148                     self.before_step()
--> 149                     self.run_step()
    150                     self.after_step()
    151                 # self.iter == max_iter can be used by `after_train` to

/usr/local/lib/python3.7/dist-packages/detectron2/engine/defaults.py in run_step(self)
    492     def run_step(self):
    493         self._trainer.iter = self.iter
--> 494         self._trainer.run_step()
    495 
    496     @classmethod

/usr/local/lib/python3.7/dist-packages/detectron2/engine/train_loop.py in run_step(self)
    265         If you want to do something with the data, you can wrap the dataloader.
    266         """
--> 267         data = next(self._data_loader_iter)
    268         data_time = time.perf_counter() - start
    269 

/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py in __iter__(self)
    232 
    233     def __iter__(self):
--> 234         for d in self.dataset:
    235             w, h = d["width"], d["height"]
    236             bucket_id = 0 if w > h else 1

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    519             if self._sampler_iter is None:
    520                 self._reset()
--> 521             data = self._next_data()
    522             self._num_yielded += 1
    523             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
   1181             if len(self._task_info[self._rcvd_idx]) == 2:
   1182                 data = self._task_info.pop(self._rcvd_idx)[1]
-> 1183                 return self._process_data(data)
   1184 
   1185             assert not self._shutdown and self._tasks_outstanding > 0

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1227         self._try_put_index()
   1228         if isinstance(data, ExceptionWrapper):
-> 1229             data.reraise()
   1230         return data
   1231 

/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
    423             # have message field
    424             raise self.exc_type(message=msg)
--> 425         raise self.exc_type(msg)
    426 
    427 

KeyError: Caught KeyError in DataLoader worker process 1.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 28, in fetch
    data.append(next(self.dataset_iter))
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py", line 201, in __iter__
    yield self.dataset[idx]
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py", line 90, in __getitem__
    data = self._map_func(self._dataset[cur_idx])
  File "/usr/local/lib/python3.7/dist-packages/detectron2/utils/serialize.py", line 26, in __call__
    return self._obj(*args, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/dataset_mapper.py", line 189, in __call__
    self._transform_annotations(dataset_dict, transforms, image_shape)
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/dataset_mapper.py", line 132, in _transform_annotations
    annos, image_shape, mask_format=self.instance_mask_format
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/detection_utils.py", line 400, in annotations_to_instances
    segms = [obj["segmentation"] for obj in annos]
  File "/usr/local/lib/python3.7/dist-packages/detectron2/data/detection_utils.py", line 400, in <listcomp>
    segms = [obj["segmentation"] for obj in annos]
KeyError: 'segmentation'

Does anyone know why this is happening, and if so, what can be done to fix it? Any input is appreciated.

Thanks!

It's hard to give a specific answer without seeing the full annotation file, but a KeyError is raised when you try to access a key that isn't in a dictionary. From the error message you posted, that key appears to be 'segmentation'.
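To illustrate the mechanism: the annotation dict below is a made-up example that has a bounding box but no mask, which is exactly the situation the last frame of your traceback (`obj["segmentation"]`) trips over:

```python
# A made-up COCO-style annotation with a bbox but no segmentation mask.
anno = {"bbox": [10, 20, 30, 40], "category_id": 3}

try:
    anno["segmentation"]  # same access pattern as detectron2's annotations_to_instances
except KeyError as e:
    print(f"missing key: {e}")  # → missing key: 'segmentation'

# dict.get() returns None (or a supplied default) instead of raising
print(anno.get("segmentation"))  # → None
```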

It's not in your code snippet, but did you do any exploration/inspection of the registered datasets before training? Some basic exploration or sanity checks would expose problems with the dataset so you can fix them early in the development process (rather than letting the trainer discover them, at which point the error messages can get long and confusing).
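One way to do that inspection directly on the raw COCO file, before registering anything with Detectron2, is to scan the top-level "annotations" list for entries with no "segmentation" key. A minimal sketch, assuming the standard COCO layout (top-level "images" and "annotations" lists; the path in the usage comment is a placeholder):

```python
import json

def annotations_missing_segmentation(coco):
    """Return file names of images that have at least one
    annotation without a 'segmentation' field."""
    # Map image id -> file name so offending images can be reported by name.
    id_to_file = {img["id"]: img["file_name"] for img in coco["images"]}
    bad = set()
    for ann in coco["annotations"]:
        if "segmentation" not in ann:
            bad.add(id_to_file[ann["image_id"]])
    return sorted(bad)

# Example usage (path is hypothetical):
# with open("/content/drive/MyDrive/Segmentation Annotations/train.json") as f:
#     print(annotations_missing_segmentation(json.load(f)))
```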

In any case, for your specific problem, you can take the registered training dataset and check whether all of its annotations have a 'segmentation' field. Below is a simple code snippet for doing that.

# Register datasets
from detectron2.data.datasets import register_coco_instances
for d in ["train", "validation"]:
    register_coco_instances(f"segmentation_{d}", {}, f"/content/drive/MyDrive/Segmentation Annotations/{d}.json", f"/content/drive/MyDrive/Segmentation Annotations/imgs")

# Check if all annotations in the registered training set have the segmentation field
from detectron2.data import DatasetCatalog

dataset_dicts_train = DatasetCatalog.get('segmentation_train')

for d in dataset_dicts_train:
    for obj in d['annotations']:
        if 'segmentation' not in obj:
            print(f'{d["file_name"]} has an annotation with no segmentation field')

It would be strange if some of your images had annotations without a 'segmentation' field, but if so, it points to a problem in your upstream annotation process.
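If the check does turn up mask-less annotations and you can't fix them upstream, one possible workaround is to drop those annotations from the JSON before registering the dataset, so the trainer only ever sees instances that have masks. A minimal sketch (the paths in the usage comment are placeholders, and this assumes you're fine discarding box-only instances):

```python
import json

def drop_annotations_without_masks(in_path, out_path):
    """Write a copy of a COCO JSON file keeping only annotations
    that carry a 'segmentation' field."""
    with open(in_path) as f:
        coco = json.load(f)
    before = len(coco["annotations"])
    coco["annotations"] = [a for a in coco["annotations"] if "segmentation" in a]
    with open(out_path, "w") as f:
        json.dump(coco, f)
    print(f"kept {len(coco['annotations'])} of {before} annotations")

# Example usage (hypothetical paths):
# drop_annotations_without_masks("train.json", "train_with_masks.json")
```

You would then point `register_coco_instances` at the cleaned file instead of the original.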