How to save a model using DefaultTrainer in Detectron2?
How do I save a checkpoint with DefaultTrainer in Detectron2?
Here is my setup:
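import os

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer

cfg = get_cfg()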
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = (DatasetLabels.TRAIN,)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 273 # Number of output classes
cfg.OUTPUT_DIR = "outputs"
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025  # Learning rate
cfg.SOLVER.MAX_ITER = 10000  # Max iterations (previously 20000)
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128  # RoIs sampled per image for the ROI heads (not the image batch size)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
# Save the model
from detectron2.checkpoint import DetectionCheckpointer, Checkpointer
checkpointer = DetectionCheckpointer(trainer, save_dir=cfg.OUTPUT_DIR)
checkpointer.save("mymodel_0")
I get the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-94-c1116902655a> in <module>()
4 checkpointer = DetectionCheckpointer(trainer, save_dir=cfg.OUTPUT_DIR)
----> 5 checkpointer.save("mymodel_0")
/usr/local/lib/python3.6/dist-packages/fvcore/common/checkpoint.py in save(self, name, **kwargs)
102
103 data = {}
--> 104 data["model"] = self.model.state_dict()
105 for key, obj in self.checkpointables.items():
106 data[key] = obj.state_dict()
AttributeError: 'DefaultTrainer' object has no attribute 'state_dict'
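Per the documentation (https://detectron2.readthedocs.io/en/latest/modules/checkpoint.html), pass the model, not the trainer. DefaultTrainer has no state_dict() method; the network it trains is trainer.model, so
checkpointer = DetectionCheckpointer(trainer.model, save_dir=cfg.OUTPUT_DIR)
checkpointer.save("mymodel_0")  # writes outputs/mymodel_0.pth
is the way to go.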
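Why passing trainer fails: as the traceback shows, fvcore's Checkpointer.save begins with data["model"] = self.model.state_dict(), so whatever object you hand over as the model must implement state_dict(). The pure-Python sketch here reproduces only that line, using hypothetical stand-in classes (MiniCheckpointer, FakeTrainer, FakeModel); none of them are detectron2 or fvcore APIs.

```python
class MiniCheckpointer:
    """Hypothetical stand-in mirroring the failing line in fvcore's Checkpointer.save."""

    def __init__(self, model):
        self.model = model

    def save_data(self):
        data = {}
        data["model"] = self.model.state_dict()  # same call as in the traceback
        return data


class FakeModel:
    """Stand-in for an nn.Module: it exposes state_dict()."""

    def state_dict(self):
        return {"weight": [1.0, 2.0]}


class FakeTrainer:
    """Stand-in for DefaultTrainer: it holds a model but has no state_dict() itself."""

    def __init__(self):
        self.model = FakeModel()


trainer = FakeTrainer()

try:
    MiniCheckpointer(trainer).save_data()  # passing the trainer, as in the question
except AttributeError as err:
    print("trainer:", err)

print("trainer.model:", MiniCheckpointer(trainer.model).save_data())  # passing the model works
```

The same reasoning applies to the real classes: DetectionCheckpointer only needs the nn.Module, which DefaultTrainer exposes as trainer.model.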
Alternatively, save only the weights with plain PyTorch:
import torch
torch.save(trainer.model.state_dict(), os.path.join(cfg.OUTPUT_DIR, "mymodel.pth"))
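A minimal sketch of the state_dict round trip, assuming only torch is installed. The small nn.Sequential returned by the hypothetical build_tiny_model() stands in for trainer.model; the point is that a state_dict holds only tensors, so loading it means rebuilding the same architecture first and then calling load_state_dict.

```python
import os
import tempfile

import torch
from torch import nn


def build_tiny_model() -> nn.Module:
    """Hypothetical stand-in for the trained network (trainer.model)."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def save_weights(model: nn.Module, output_dir: str, name: str = "mymodel.pth") -> str:
    """Save only the parameters and buffers, not the Python object."""
    path = os.path.join(output_dir, name)
    torch.save(model.state_dict(), path)
    return path


def load_weights(path: str) -> nn.Module:
    """Rebuild the same architecture, then copy the saved tensors into it."""
    model = build_tiny_model()
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state)
    return model.eval()  # eval() returns the module itself


if __name__ == "__main__":
    torch.manual_seed(0)
    original = build_tiny_model()
    with tempfile.TemporaryDirectory() as output_dir:
        restored = load_weights(save_weights(original, output_dir))
    x = torch.randn(3, 4)
    print(torch.allclose(original(x), restored(x)))  # same weights, same outputs
```

For a detectron2 model, the rebuild step would be detectron2.modeling.build_model(cfg) with the same cfg used for training, followed by load_state_dict on the result.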
You can also save the whole model object:
torch.save(trainer.model, "MyCustom/path/mymodel.pth")
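This saves the complete model object (pickled), not just its weights. To load it, detectron2 must be installed in your Python environment. You can load it like this:
model = torch.load("MyCustom/path/mymodel.pth")  # on PyTorch >= 2.6, also pass weights_only=False to unpickle a full model
However, with this approach you cannot use detectron2's DefaultPredictor, which loads a weights checkpoint from cfg.MODEL.WEIGHTS rather than a pickled model object.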
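Why detectron2 has to be installed: torch.save(model, ...) pickles the model, and pickle records classes by module and name instead of storing their code. This stdlib sketch imitates that with a throwaway module named fake_lib (hypothetical, standing in for detectron2): once the module is gone, the pickled object can no longer be loaded, while plain weights still can. It is only an analogy using pickle directly; torch.save adds its own tensor handling on top.

```python
import pickle
import sys
import types

# Hypothetical module "fake_lib" playing the role of an installed library
# such as detectron2. Classes defined in it get __module__ == "fake_lib".
fake_lib = types.ModuleType("fake_lib")
exec(
    "class TinyModel:\n"
    "    def __init__(self):\n"
    "        self.weights = {'w': [0.1, 0.2, 0.3]}\n"
    "    def state_dict(self):\n"
    "        return dict(self.weights)\n",
    fake_lib.__dict__,
)
sys.modules["fake_lib"] = fake_lib  # make the class importable so pickle can find it

model = fake_lib.TinyModel()
full_blob = pickle.dumps(model)                  # like torch.save(model, path)
weights_blob = pickle.dumps(model.state_dict())  # like torch.save(model.state_dict(), path)

# Simulate loading in an environment where the library is not installed.
del sys.modules["fake_lib"]

try:
    pickle.loads(full_blob)
    full_model_loaded = True
except ImportError as err:  # ModuleNotFoundError is a subclass of ImportError
    full_model_loaded = False
    print("full model:", err)

weights = pickle.loads(weights_blob)  # needs only builtins (dict, list, float)
print("weights:", weights)
```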