使用pytorch计算训练损失和验证损失的区别

Difference between the calculation of the training loss and validation loss using pytorch

我想将这个传统图像分类问题的以下代码用于我的回归问题。代码可以在这里找到:

GeeksforGeeks-Training Neural Networks with Validation using Pytorch

class Network(nn.Module):
    def __init__(self):
        super(Network,self).__init__()
        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(1,-1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Network()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

epochs = 5

for e in range(epochs):
    train_loss = 0.0
    model.train()     # Optional when not using Model Specific layer
    for data, labels in trainloader:
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        
        optimizer.zero_grad()
        target = model(data)
        loss = criterion(target,labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    
    valid_loss = 0.0
    model.eval()     # Optional when not using Model Specific layer
    for data, labels in validloader:
        if torch.cuda.is_available():
            data, labels = data.cuda(), labels.cuda()
        
        target = model(data)
        loss = criterion(target,labels)
        valid_loss = loss.item() * data.size(0)

print(f'Epoch {e+1} \t\t Training Loss: {train_loss / len(trainloader)} \t\t Validation Loss: {valid_loss / len(validloader)}')

我能理解为什么这个例子中training loss要加起来再除以训练数据的长度,但是我不明白为什么validation loss也不加起来再除以长度。如果我没理解错的话,这里的validation loss是用最后一批的validation loss再乘以batch size的长度来计算的。

计算验证损失的方法是否正确?假设我使用特定于回归的指标(例如 MSE 而不是 CrossEntropyLoss 等),我可以将代码用于我的回归问题吗?

是的,您可以将代码用于回归任务。代码示例的目标是 one-hot 向量,或者在 MNIST 示例中是数字 0 到 9,它们表示 类。在回归案例中,您可以从中得到一个标量。损失函数,也就是例子中的cross-entropy,在你的情况下可以用MSE代替。

我假设此示例中的验证损失仅通过从单个数据点外推到所有其他数据点来估算。 由于 data.size 代表批量大小,即使平均也只会在丢失单个数据点的情况下得出。

然而,在网页上,验证损失是在验证集中的所有数据点上计算的,这是应该的。