如何知道训练好的模型是正确的?
How to know the trained model is correct?
我使用PyTorch Lightning进行模型训练,期间我使用ModelCheckpoint
保存加载点。最后,我想知道模型加载是否正确。如果您需要更多信息,请告诉我?
checkpoint_callback = ModelCheckpoint(
filename='tb1000_{epoch: 02d}-{step}',
monitor='val/acc@1',
save_top_k=5,
mode='max')
wandb_logger = pl.loggers.wandb.WandbLogger(
name=run_name,
project=args.project,
entity=args.entity,
offline=args.offline,
log_model='all')
model = BYOL(**args.__dict__, num_classes=dm.num_classes)
trainer = pl.Trainer.from_argparse_args(args,
logger=wandb_logger, callbacks=[checkpoint_callback])
trainer.fit(model, dm)
# Loading and testing
model_test = BYOL(**args.__dict__, num_classes=dm.num_classes)
path = "/tb100_epoch= 819-step=39359.ckpt"
model_test.load_from_checkpoint(path)
load_from_checkpoint()
将 return 一个具有训练权重的模型,因此您需要将其分配给一个新变量。
model_test = model_test.load_from_checkpoint(path)
或
model_test = BYOL.load_from_checkpoint(path)
我使用PyTorch Lightning进行模型训练,期间我使用ModelCheckpoint
保存加载点。最后,我想知道模型加载是否正确。如果您需要更多信息,请告诉我?
checkpoint_callback = ModelCheckpoint(
filename='tb1000_{epoch: 02d}-{step}',
monitor='val/acc@1',
save_top_k=5,
mode='max')
wandb_logger = pl.loggers.wandb.WandbLogger(
name=run_name,
project=args.project,
entity=args.entity,
offline=args.offline,
log_model='all')
model = BYOL(**args.__dict__, num_classes=dm.num_classes)
trainer = pl.Trainer.from_argparse_args(args,
logger=wandb_logger, callbacks=[checkpoint_callback])
trainer.fit(model, dm)
# Loading and testing
model_test = BYOL(**args.__dict__, num_classes=dm.num_classes)
path = "/tb100_epoch= 819-step=39359.ckpt"
model_test.load_from_checkpoint(path)
load_from_checkpoint()
将 return 一个具有训练权重的模型,因此您需要将其分配给一个新变量。
model_test = model_test.load_from_checkpoint(path)
或
model_test = BYOL.load_from_checkpoint(path)