使用 Detectron2 训练自定义 COCO 数据集时出错
Error Training Custom COCO Dataset with Detectron2
我正在尝试在 PyTorch 上使用 Detectron2 训练自定义 COCO 格式数据集。我的数据集是 COCO 格式的 JSON 文件,其 "annotations" 部分中的每一项如下所示:
设置Detectron2和注册训练&验证数据集的代码如下:
# Register the custom COCO-format train/validation splits, then fine-tune a
# Mask R-CNN instance-segmentation baseline on the training split.
import os

from detectron2 import model_zoo  # was missing: model_zoo is used below -> NameError
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer

for d in ["train", "validation"]:
    register_coco_instances(
        f"segmentation_{d}",
        {},
        f"/content/drive/MyDrive/Segmentation Annotations/{d}.json",
        f"/content/drive/MyDrive/Segmentation Annotations/imgs",
    )

cfg = get_cfg()
# Start from the COCO Mask R-CNN R50-FPN 3x baseline config and its pretrained weights.
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("segmentation_train",)
cfg.DATASETS.TEST = ()  # no evaluation during training
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 1000
cfg.SOLVER.STEPS = []  # empty -> constant learning rate, no decay
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 20  # number of classes in the custom dataset

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)  # os was missing from the imports too
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)  # start from cfg.MODEL.WEIGHTS, not a checkpoint
trainer.train()
然而,当我运行训练时,在第一次迭代之后得到以下错误:
KeyError Traceback (most recent call last)
<ipython-input-12-2aaec108c313> in <module>()
17 trainer = DefaultTrainer(cfg)
18 trainer.resume_or_load(resume=False)
---> 19 trainer.train()
8 frames
/usr/local/lib/python3.7/dist-packages/detectron2/engine/defaults.py in train(self)
482 OrderedDict of results, if evaluation is enabled. Otherwise None.
483 """
--> 484 super().train(self.start_iter, self.max_iter)
485 if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
486 assert hasattr(
/usr/local/lib/python3.7/dist-packages/detectron2/engine/train_loop.py in train(self, start_iter, max_iter)
147 for self.iter in range(start_iter, max_iter):
148 self.before_step()
--> 149 self.run_step()
150 self.after_step()
151 # self.iter == max_iter can be used by `after_train` to
/usr/local/lib/python3.7/dist-packages/detectron2/engine/defaults.py in run_step(self)
492 def run_step(self):
493 self._trainer.iter = self.iter
--> 494 self._trainer.run_step()
495
496 @classmethod
/usr/local/lib/python3.7/dist-packages/detectron2/engine/train_loop.py in run_step(self)
265 If you want to do something with the data, you can wrap the dataloader.
266 """
--> 267 data = next(self._data_loader_iter)
268 data_time = time.perf_counter() - start
269
/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py in __iter__(self)
232
233 def __iter__(self):
--> 234 for d in self.dataset:
235 w, h = d["width"], d["height"]
236 bucket_id = 0 if w > h else 1
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
519 if self._sampler_iter is None:
520 self._reset()
--> 521 data = self._next_data()
522 self._num_yielded += 1
523 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
1181 if len(self._task_info[self._rcvd_idx]) == 2:
1182 data = self._task_info.pop(self._rcvd_idx)[1]
-> 1183 return self._process_data(data)
1184
1185 assert not self._shutdown and self._tasks_outstanding > 0
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1227 self._try_put_index()
1228 if isinstance(data, ExceptionWrapper):
-> 1229 data.reraise()
1230 return data
1231
/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
423 # have message field
424 raise self.exc_type(message=msg)
--> 425 raise self.exc_type(msg)
426
427
KeyError: Caught KeyError in DataLoader worker process 1.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 28, in fetch
data.append(next(self.dataset_iter))
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py", line 201, in __iter__
yield self.dataset[idx]
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py", line 90, in __getitem__
data = self._map_func(self._dataset[cur_idx])
File "/usr/local/lib/python3.7/dist-packages/detectron2/utils/serialize.py", line 26, in __call__
return self._obj(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/dataset_mapper.py", line 189, in __call__
self._transform_annotations(dataset_dict, transforms, image_shape)
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/dataset_mapper.py", line 132, in _transform_annotations
annos, image_shape, mask_format=self.instance_mask_format
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/detection_utils.py", line 400, in annotations_to_instances
segms = [obj["segmentation"] for obj in annos]
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/detection_utils.py", line 400, in <listcomp>
segms = [obj["segmentation"] for obj in annos]
KeyError: 'segmentation'
大家知道为什么会发生这种情况吗?如果知道,可以采取什么措施来解决?欢迎任何意见。
谢谢!
如果不查看完整的注释文件,很难给出具体的答案。不过,当尝试访问字典中不存在的键时就会引发 KeyError 异常;从您发布的错误消息来看,这个键似乎是 'segmentation'。
这不在您的代码片段中,但在开始训练网络之前,您是否对已注册的数据集做过任何探索或检查?进行一些基本的探索或检查可以尽早暴露数据集的问题,从而在开发过程的早期修复它们(而不是等到训练器(trainer)发现它们——那时的错误消息往往冗长且令人困惑)。
无论如何,对于您的具体问题,您可以使用已注册的训练数据集并检查是否所有注释都具有 'segmentation'
字段。下面是执行此操作的简单代码片段。
# Register datasets
from detectron2.data.datasets import register_coco_instances

for d in ["train", "validation"]:
    register_coco_instances(
        f"segmentation_{d}",
        {},
        f"/content/drive/MyDrive/Segmentation Annotations/{d}.json",
        f"/content/drive/MyDrive/Segmentation Annotations/imgs",
    )

# Check if all annotations in the registered training set have the segmentation field
from detectron2.data import DatasetCatalog

dataset_dicts_train = DatasetCatalog.get('segmentation_train')
for d in dataset_dicts_train:
    # .get() keeps the checker itself from crashing with a KeyError on records
    # that have no "annotations" key at all (another possible upstream problem).
    for obj in d.get('annotations', []):
        if 'segmentation' not in obj:
            print(f'{d["file_name"]} has an annotation with no segmentation field')
如果某些图像的注释中没有 'segmentation'
字段,那会很奇怪,但这表明您的上游注释过程存在问题。
我正在尝试在 PyTorch 上使用 Detectron2 训练自定义 COCO 格式数据集。我的数据集是 COCO 格式的 JSON 文件,其 "annotations" 部分中的每一项如下所示:
设置Detectron2和注册训练&验证数据集的代码如下:
# Register the custom COCO-format train/validation splits, then fine-tune a
# Mask R-CNN instance-segmentation baseline on the training split.
import os

from detectron2 import model_zoo  # was missing: model_zoo is used below -> NameError
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer

for d in ["train", "validation"]:
    register_coco_instances(
        f"segmentation_{d}",
        {},
        f"/content/drive/MyDrive/Segmentation Annotations/{d}.json",
        f"/content/drive/MyDrive/Segmentation Annotations/imgs",
    )

cfg = get_cfg()
# Start from the COCO Mask R-CNN R50-FPN 3x baseline config and its pretrained weights.
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("segmentation_train",)
cfg.DATASETS.TEST = ()  # no evaluation during training
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 1000
cfg.SOLVER.STEPS = []  # empty -> constant learning rate, no decay
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 20  # number of classes in the custom dataset

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)  # os was missing from the imports too
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)  # start from cfg.MODEL.WEIGHTS, not a checkpoint
trainer.train()
然而,当我运行训练时,在第一次迭代之后得到以下错误:
KeyError Traceback (most recent call last)
<ipython-input-12-2aaec108c313> in <module>()
17 trainer = DefaultTrainer(cfg)
18 trainer.resume_or_load(resume=False)
---> 19 trainer.train()
8 frames
/usr/local/lib/python3.7/dist-packages/detectron2/engine/defaults.py in train(self)
482 OrderedDict of results, if evaluation is enabled. Otherwise None.
483 """
--> 484 super().train(self.start_iter, self.max_iter)
485 if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
486 assert hasattr(
/usr/local/lib/python3.7/dist-packages/detectron2/engine/train_loop.py in train(self, start_iter, max_iter)
147 for self.iter in range(start_iter, max_iter):
148 self.before_step()
--> 149 self.run_step()
150 self.after_step()
151 # self.iter == max_iter can be used by `after_train` to
/usr/local/lib/python3.7/dist-packages/detectron2/engine/defaults.py in run_step(self)
492 def run_step(self):
493 self._trainer.iter = self.iter
--> 494 self._trainer.run_step()
495
496 @classmethod
/usr/local/lib/python3.7/dist-packages/detectron2/engine/train_loop.py in run_step(self)
265 If you want to do something with the data, you can wrap the dataloader.
266 """
--> 267 data = next(self._data_loader_iter)
268 data_time = time.perf_counter() - start
269
/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py in __iter__(self)
232
233 def __iter__(self):
--> 234 for d in self.dataset:
235 w, h = d["width"], d["height"]
236 bucket_id = 0 if w > h else 1
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
519 if self._sampler_iter is None:
520 self._reset()
--> 521 data = self._next_data()
522 self._num_yielded += 1
523 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
1181 if len(self._task_info[self._rcvd_idx]) == 2:
1182 data = self._task_info.pop(self._rcvd_idx)[1]
-> 1183 return self._process_data(data)
1184
1185 assert not self._shutdown and self._tasks_outstanding > 0
/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1227 self._try_put_index()
1228 if isinstance(data, ExceptionWrapper):
-> 1229 data.reraise()
1230 return data
1231
/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
423 # have message field
424 raise self.exc_type(message=msg)
--> 425 raise self.exc_type(msg)
426
427
KeyError: Caught KeyError in DataLoader worker process 1.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 28, in fetch
data.append(next(self.dataset_iter))
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py", line 201, in __iter__
yield self.dataset[idx]
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/common.py", line 90, in __getitem__
data = self._map_func(self._dataset[cur_idx])
File "/usr/local/lib/python3.7/dist-packages/detectron2/utils/serialize.py", line 26, in __call__
return self._obj(*args, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/dataset_mapper.py", line 189, in __call__
self._transform_annotations(dataset_dict, transforms, image_shape)
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/dataset_mapper.py", line 132, in _transform_annotations
annos, image_shape, mask_format=self.instance_mask_format
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/detection_utils.py", line 400, in annotations_to_instances
segms = [obj["segmentation"] for obj in annos]
File "/usr/local/lib/python3.7/dist-packages/detectron2/data/detection_utils.py", line 400, in <listcomp>
segms = [obj["segmentation"] for obj in annos]
KeyError: 'segmentation'
大家知道为什么会发生这种情况吗?如果知道,可以采取什么措施来解决?欢迎任何意见。
谢谢!
如果不查看完整的注释文件,很难给出具体的答案。不过,当尝试访问字典中不存在的键时就会引发 KeyError 异常;从您发布的错误消息来看,这个键似乎是 'segmentation'。
这不在您的代码片段中,但在开始训练网络之前,您是否对已注册的数据集做过任何探索或检查?进行一些基本的探索或检查可以尽早暴露数据集的问题,从而在开发过程的早期修复它们(而不是等到训练器(trainer)发现它们——那时的错误消息往往冗长且令人困惑)。
无论如何,对于您的具体问题,您可以使用已注册的训练数据集并检查是否所有注释都具有 'segmentation'
字段。下面是执行此操作的简单代码片段。
# Register datasets
from detectron2.data.datasets import register_coco_instances

for d in ["train", "validation"]:
    register_coco_instances(
        f"segmentation_{d}",
        {},
        f"/content/drive/MyDrive/Segmentation Annotations/{d}.json",
        f"/content/drive/MyDrive/Segmentation Annotations/imgs",
    )

# Check if all annotations in the registered training set have the segmentation field
from detectron2.data import DatasetCatalog

dataset_dicts_train = DatasetCatalog.get('segmentation_train')
for d in dataset_dicts_train:
    # .get() keeps the checker itself from crashing with a KeyError on records
    # that have no "annotations" key at all (another possible upstream problem).
    for obj in d.get('annotations', []):
        if 'segmentation' not in obj:
            print(f'{d["file_name"]} has an annotation with no segmentation field')
如果某些图像的注释中没有 'segmentation'
字段,那会很奇怪,但这表明您的上游注释过程存在问题。