TypeError: load_checkpoint() missing 1 required positional argument: 'ckpt_path'

TypeError: load_checkpoint() missing 1 required positional argument: 'ckpt_path'

请帮我解决以下问题。我想从检查点恢复训练所以我进入

python main.py --config cfgs/cifar10.yaml --resume checkpoint/cifar10/ckpt.pth.tar

在控制台中,但它不起作用。我收到错误消息。

Traceback (most recent call last):   File "main.py", line 287, in <module>
    main()   File "main.py", line 85, in main
    start_epoch, best_nmi = load_checkpoint(model, dim_loss, optimizer, args.resume) TypeError: load_checkpoint() missing 1 required positional argument: 'ckpt_path'

部分代码如下

# argparser
parser = argparse.ArgumentParser(description='PyTorch Implementation of DCCM')
parser.add_argument('--resume', default=None, type=str, help='resume from a checkpoint')
parser.add_argument('--config', default='cfgs/config.yaml/', help='set sconfiguration file')
parser.add_argument('--small_bs', default=32, type=int)
parser.add_argument('--input_size', default=96, type=int)
parser.add_argument('--split', default=None, type=int, help='divide the large forward batch to avoid OOM')

恢复训练

if args.resume:
    logger.info("=> loading checkpoint '{}'".format(args.resume))
    start_epoch, best_nmi = load_checkpoint(model, dim_loss, optimizer, args.resume) #line85  

保存并加载检查点

 def load_checkpoint(model, dim_loss, classifier, optimizer, ckpt_path):

    checkpoint = torch.load(ckpt_path)
    
    model.load_state_dict(checkpoint['model'])
    dim_loss.load_state_dict(checkpoint['dim_loss'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    best_nmi = checkpoint['best_nmi']
    start_epoch = checkpoint['epoch']

    return start_epoch, best_nmi


def save_checkpoint(state, is_best_nmi, filename):
    torch.save(state, filename+'.pth.tar')
    if is_best_nmi:
        shutil.copyfile(filename+'.pth.tar', filename+'_best_nmi.pth.tar')

谢谢

当你定义 load_checkpoint 时它需要 5 个参数,但在第 85 行你只传递了 4 个,并且像 none 一样它们有一个默认值,所有这些都是必需的。当您将参数作为位置参数传递时,您正在用 ckpt_path 的值填充 optimizer。您有 2 个选择:取出函数定义的 classifier 参数(因为它永远不会使用)或给它一个值,默认值或在运行时。