TypeError: load_checkpoint() missing 1 required positional argument: 'ckpt_path'
Please help me solve the following problem. I want to resume training from a checkpoint, so I enter

python main.py --config cfgs/cifar10.yaml --resume checkpoint/cifar10/ckpt.pth.tar

in the console, but it doesn't work. I get this error message:
Traceback (most recent call last):
  File "main.py", line 287, in <module>
    main()
  File "main.py", line 85, in main
    start_epoch, best_nmi = load_checkpoint(model, dim_loss, optimizer, args.resume)
TypeError: load_checkpoint() missing 1 required positional argument: 'ckpt_path'
Part of the code is below.
# argparser
parser = argparse.ArgumentParser(description='PyTorch Implementation of DCCM')
parser.add_argument('--resume', default=None, type=str, help='resume from a checkpoint')
parser.add_argument('--config', default='cfgs/config.yaml/', help='set configuration file')
parser.add_argument('--small_bs', default=32, type=int)
parser.add_argument('--input_size', default=96, type=int)
parser.add_argument('--split', default=None, type=int, help='divide the large forward batch to avoid OOM')
# resume training
if args.resume:
    logger.info("=> loading checkpoint '{}'".format(args.resume))
    start_epoch, best_nmi = load_checkpoint(model, dim_loss, optimizer, args.resume)  # line 85
# save and load checkpoints
def load_checkpoint(model, dim_loss, classifier, optimizer, ckpt_path):
    checkpoint = torch.load(ckpt_path)
    model.load_state_dict(checkpoint['model'])
    dim_loss.load_state_dict(checkpoint['dim_loss'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_nmi = checkpoint['best_nmi']
    start_epoch = checkpoint['epoch']
    return start_epoch, best_nmi
def save_checkpoint(state, is_best_nmi, filename):
    torch.save(state, filename + '.pth.tar')
    if is_best_nmi:
        shutil.copyfile(filename + '.pth.tar', filename + '_best_nmi.pth.tar')
Thank you.
Your definition of load_checkpoint takes 5 arguments, but on line 85 you pass only 4, and since none of them has a default value, all 5 are required. Because you pass the arguments positionally, the checkpoint path (args.resume) lands in the optimizer parameter, and ckpt_path is left unfilled. You have 2 options: remove the classifier parameter from the function definition (since it is never used in the body), or give it a value, either as a default or at the call site.
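To see the mismatch in isolation, here is a minimal sketch with toy stand-in arguments (not the real DCCM objects), showing how the four positional values shift into the wrong parameters, and the first fix (dropping the unused classifier parameter):

```python
# Five parameters defined, but the caller passes only four positionally:
# model -> model, dim_loss -> dim_loss, optimizer -> classifier,
# args.resume -> optimizer, and ckpt_path is left unfilled -> TypeError.
def load_checkpoint(model, dim_loss, classifier, optimizer, ckpt_path):
    return ckpt_path

try:
    load_checkpoint("model", "dim_loss", "optimizer", "ckpt.pth.tar")
except TypeError as e:
    print(e)  # ... missing 1 required positional argument: 'ckpt_path'

# Option 1: remove the unused `classifier` parameter, so the existing
# call on line 85 matches the signature unchanged.
def load_checkpoint_fixed(model, dim_loss, optimizer, ckpt_path):
    return ckpt_path

print(load_checkpoint_fixed("model", "dim_loss", "optimizer", "ckpt.pth.tar"))
```

Option 2 would be keeping the signature and either giving classifier a default (classifier=None) or passing the actual classifier object at the call site.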