PyTorch:Dropout(?)导致“训练+验证”与“仅训练”两种设置下模型的收敛结果不同

PyTorch: Dropout (?) causes different model convergence for training+validation vs. training-only

我们遇到了一个非常奇怪的问题。我们在两种不同的“执行”设置下测试了完全相同的模型。在第一种设置中,给定一定数量的 epoch,我们使用 mini-batch 训练一个 epoch,然后按照相同的标准在验证集上进行测试,接着进入下一个 epoch。当然,在每个训练 epoch 之前我们会调用 model.train(),在验证之前会调用 model.eval()。

在第二种设置中,我们采用完全相同的模型(相同的初始值、相同的数据集、相同的 epoch 数等),只是单纯训练它,而不在每个 epoch 之后进行验证。

仅比较训练集上的性能,我们观察到:即使我们固定了所有随机种子,两个训练过程的演进方式仍然不同,并产生完全不同的指标结果(损失、准确率等)。具体来说,“仅训练”设置的性能更差。

我们还观察到以下情况:

我们花了一整天的时间来解决这个问题(不,我们无法避免使用 dropout)。有人知道如何解决这个问题吗?

非常感谢!

P.S. 以下是用于训练的代码(在 CPU 上运行)。

def sigmoid(x):
    """Return the element-wise logistic sigmoid of tensor ``x``.

    Delegates to ``torch.sigmoid`` instead of evaluating
    ``1 / (1 + exp(-x))`` by hand: the built-in avoids the intermediate
    overflow of ``exp(-x)`` for large negative inputs and yields finite
    gradients there.
    """
    return torch.sigmoid(x)


def _run(model, EPOCHS, training_data_in, validation_data_in=None):
    """Train ``model`` for ``EPOCHS`` epochs, optionally validating after each.

    Per epoch the model is trained on mini-batches from ``training_data_in``
    and, when ``validation_data_in`` is given, evaluated on it under
    ``torch.no_grad()``. Loss, accuracy, confusion-matrix counts and the
    derived precision/recall/F1/specificity are printed and appended to a
    history dict.

    The torch RNG state is saved before validation and restored afterwards.
    Creating a DataLoader iterator draws a base seed from the global torch
    generator, so without the restore a run *with* validation would see
    different shuffling/dropout randomness than a train-only run that
    started from the same seed.

    Args:
        model: Module called as ``model(ecg)``; its output is fed to
            ``BCEWithLogitsLoss``, so it is treated as logits.
        EPOCHS: Number of training epochs.
        training_data_in: Dataset yielding ``(ecg, spo2, labels)`` items.
            NOTE(review): labels are presumably float 0/1 tensors with the
            same shape as the model output -- confirm against the dataset.
        validation_data_in: Optional dataset with the same item layout;
            ``None`` disables validation.

    Raises:
        ValueError: If a data loader yields no batches.

    Side effects:
        Writes the model ``state_dict`` to ``./nnet_model.pt`` and the
        metrics history to ``training_history``.
    """
    import torch.optim as optim

    def _batch_stats(outputs, labels):
        """Return ``(accuracy, tp, tn, fp, fn)`` for one batch of logits."""
        # The sigmoid is applied here because it is not part of the model.
        predicted = torch.round(sigmoid(outputs.detach()))
        accuracy = (predicted == labels).sum().item() / labels.size(0)
        difference = predicted - labels
        total = predicted + labels
        fp = (difference == 1.).sum().item()
        fn = (difference == -1.).sum().item()
        tp = (total == 2.).sum().item()
        tn = (total == 0.).sum().item()
        return accuracy, tp, tn, fp, fn

    def _run_epoch(dataloader, model, criterion, optimizer=None):
        """Run one pass over ``dataloader``; update weights iff ``optimizer`` is given."""
        totals = {'loss': 0., 'accuracy': 0., 'tp': 0., 'tn': 0., 'fp': 0., 'fn': 0.}
        n_batches = 0

        for ecg, _spo2, labels in dataloader:
            if optimizer is not None:
                optimizer.zero_grad()

            outputs = model(ecg)
            loss = criterion(outputs, labels)

            if optimizer is not None:
                loss.backward()  # calculate the gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()  # update the network weights

            accuracy, tp, tn, fp, fn = _batch_stats(outputs, labels)
            totals['loss'] += loss.item()
            totals['accuracy'] += accuracy
            totals['tp'] += tp
            totals['tn'] += tn
            totals['fp'] += fp
            totals['fn'] += fn
            n_batches += 1

        if n_batches == 0:
            raise ValueError("data loader yielded no batches")

        # Loss and accuracy are reported as per-batch means; counts stay totals.
        totals['loss'] /= n_batches
        totals['accuracy'] /= n_batches
        return totals

    def train_fn(train_dataloader, model, optimizer, criterion):
        """Train for one epoch and return the epoch's metrics dict."""
        model.train()
        return _run_epoch(train_dataloader, model, criterion, optimizer)

    def valid_fn(valid_dataloader, model, criterion):
        """Evaluate for one epoch (no weight updates) and return the metrics dict."""
        model.eval()
        return _run_epoch(valid_dataloader, model, criterion)

    def _derived_metrics(metrics):
        """Return ``(precision, recall, f1, specificity)``; 0 where undefined."""
        tp, tn, fp, fn = metrics['tp'], metrics['tn'], metrics['fp'], metrics['fn']
        precision = tp / (tp + fp) if tp > 0 else 0
        recall = tp / (tp + fn) if tp > 0 else 0
        specificity = tn / (tn + fp) if tn > 0 else 0
        f1 = 2*precision*recall / (precision + recall) if precision*recall > 0 else 0
        return precision, recall, f1, specificity

    def _record(metrics, derived, prefix=""):
        """Append one epoch's raw and derived metrics to ``metrics_history``."""
        for key in ("loss", "accuracy", "tp", "tn", "fp", "fn"):
            metrics_history[prefix + key].append(metrics[key])
        precision, recall, f1, specificity = derived
        metrics_history[prefix + "precision"].append(precision)
        metrics_history[prefix + "recall"].append(recall)
        metrics_history[prefix + "f1"].append(f1)
        metrics_history[prefix + "specificity"].append(specificity)

    has_validation = validation_data_in is not None

    # Defining data loaders
    train_dataloader = torch.utils.data.DataLoader(
        training_data_in, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
    if has_validation:
        validation_dataloader = torch.utils.data.DataLoader(
            validation_data_in, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

    # Defining the loss function
    criterion = nn.BCEWithLogitsLoss()

    # Defining the optimizer
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, amsgrad=False, eps=1e-07)

    # The "*_bis" keys are never filled here; they are kept so the saved
    # history keeps the same layout as before.
    metrics_history = {"loss":[], "accuracy":[], "precision":[], "recall":[], "f1":[], "specificity":[], "accuracy_bis":[], "tp":[], "tn":[], "fp":[], "fn":[],
                    "val_loss":[], "val_accuracy":[], "val_precision":[], "val_recall":[], "val_f1":[], "val_specificity":[], "val_accuracy_bis":[], "val_tp":[], "val_tn":[], "val_fp":[], "val_fn":[],}

    # Training code
    train_begin = time.time()
    for epoch in range(EPOCHS):
        start = time.time()

        print("EPOCH:", epoch+1)

        train_metrics = train_fn(train_dataloader=train_dataloader,
                                 model=model,
                                 optimizer=optimizer,
                                 criterion=criterion)
        precision, recall, f1, specificity = _derived_metrics(train_metrics)
        _record(train_metrics, (precision, recall, f1, specificity))

        if has_validation:
            # Snapshot the (CPU) torch RNG so that validation leaves the
            # training random stream exactly where a train-only run would be.
            rng_state = torch.get_rng_state()
            try:
                with torch.no_grad():  # no gradient bookkeeping during evaluation
                    val_metrics = valid_fn(valid_dataloader=validation_dataloader,
                                           model=model,
                                           criterion=criterion)
            finally:
                torch.set_rng_state(rng_state)

            val_precision, val_recall, val_f1, val_specificity = _derived_metrics(val_metrics)
            _record(val_metrics,
                    (val_precision, val_recall, val_f1, val_specificity),
                    prefix="val_")

            print("  > Training/validation loss:", round(train_metrics['loss'], 4), round(val_metrics['loss'], 4))
            print("  > Training/validation accuracy:", round(train_metrics['accuracy'], 4), round(val_metrics['accuracy'], 4))
            print("  > Training/validation precision:", round(precision, 4), round(val_precision, 4))
            print("  > Training/validation recall:", round(recall, 4), round(val_recall, 4))
            print("  > Training/validation f1:", round(f1, 4), round(val_f1, 4))
            print("  > Training/validation specificity:", round(specificity, 4), round(val_specificity, 4))
        else:
            print("  > Training loss:", round(train_metrics['loss'], 4))
            print("  > Training accuracy:", round(train_metrics['accuracy'], 4))
            print("  > Training precision:", round(precision, 4))
            print("  > Training recall:", round(recall, 4))
            print("  > Training f1:", round(f1, 4))
            print("  > Training specificity:", round(specificity, 4))

        print("Completed in:", round(time.time() - start, 1), "seconds \n")

    print("Training completed in:", round((time.time()- train_begin)/60, 1), "minutes")

    # Save the model weights
    torch.save(model.state_dict(), './nnet_model.pt')

    # Save the metrics history
    torch.save(metrics_history, 'training_history')

这里是初始化模型和设置种子的函数,在每次执行“_run”代码之前调用:

def reinit_model():
    """Seed the torch, NumPy and ``random`` generators with 42 and return a new ``Net``.

    The generators are seeded before the model is constructed, in the
    order torch, NumPy, ``random``.
    """
    seed = 42
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    return Net()  # the model

好的,我找到问题了。原因在于:运行评估(验证)显然会改变某些随机数生成器的状态,进而影响后续的训练阶段。

因此解决方案如下:

  • 在函数“_run()”的开头,将所有种子状态设置为所需的值,例如 42。然后,将这些种子保存到磁盘。
  • 在函数“train_fn()”的开头,从磁盘读取种子状态,并设置它们
  • 在函数“train_fn()”的末尾,将种子状态保存到磁盘

例如,在带有 XLA 的 TPU 上运行时,必须使用以下指令:

  • 在函数“_run()”的开头:xm.set_rng_state(42), xm.save(xm.get_rng_state(), 'xm_seed')
  • 在函数“train_fn()”的开头:xm.set_rng_state(torch.load('xm_seed'), device=device)(您也可以在此处打印种子以进行验证:xm.master_print(xm.get_rng_state()))
  • 在函数“train_fn()”的末尾:xm.save(xm.get_rng_state(), 'xm_seed')