PyTorch 中的 运行 损失是什么,它是如何计算的
What is running loss in PyTorch and how is it calculated
我查看了 PyTorch 文档中的 this 教程以了解迁移学习。有一行我没看懂。
使用loss = criterion(outputs, labels)
计算损失后,使用running_loss += loss.item() * inputs.size(0)
计算运行损失,最后使用running_loss / dataset_sizes[phase]
计算epoch损失。
loss.item()
不应该用于整个小批量(如果我错了请纠正我)。即,如果 batch_size
为 4,则 loss.item()
将给出整组 4 张图像的损失。如果这是真的,为什么在计算 running_loss
时 loss.item()
与 inputs.size(0)
相乘?在这种情况下,这一步是不是像一个额外的乘法?
如有任何帮助,我们将不胜感激。谢谢!
是因为CrossEntropy
或其他损失函数给出的损失除以元素个数即reduction参数默认为mean
。
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
因此,loss.item()
包含整个小批量的损失,但除以批量大小。这就是为什么在计算 running_loss
.
时 loss.item()
乘以 inputs.size(0)
给出的批量大小
if the batch_size is 4, loss.item() would give the loss for the entire set of 4 images
这取决于 loss
的计算方式。请记住,loss
是一个张量,就像其他所有张量一样。一般来说,PyTorch API 默认 return avg loss
"The losses are averaged across observations for each minibatch."
t.item()
对于张量 t
只是将其转换为 python 的默认 float32。
更重要的是,如果您是 PyTorch 的新手,了解我们使用 t.item()
来维持 运行 损失而不是 t
可能会对您有所帮助,因为 PyTorch 张量存储其值的历史可能会很快使您的 GPU 超载。
我查看了 PyTorch 文档中的 this 教程以了解迁移学习。有一行我没看懂。
使用loss = criterion(outputs, labels)
计算损失后,使用running_loss += loss.item() * inputs.size(0)
计算运行损失,最后使用running_loss / dataset_sizes[phase]
计算epoch损失。
loss.item()
不应该用于整个小批量(如果我错了请纠正我)。即,如果 batch_size
为 4,则 loss.item()
将给出整组 4 张图像的损失。如果这是真的,为什么在计算 running_loss
时 loss.item()
与 inputs.size(0)
相乘?在这种情况下,这一步是不是像一个额外的乘法?
如有任何帮助,我们将不胜感激。谢谢!
是因为CrossEntropy
或其他损失函数给出的损失除以元素个数即reduction参数默认为mean
。
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
因此,loss.item()
包含整个小批量的损失,但除以批量大小。这就是为什么在计算 running_loss
.
loss.item()
乘以 inputs.size(0)
给出的批量大小
if the batch_size is 4, loss.item() would give the loss for the entire set of 4 images
这取决于 loss
的计算方式。请记住,loss
是一个张量,就像其他所有张量一样。一般来说,PyTorch API 默认 return avg loss
"The losses are averaged across observations for each minibatch."
t.item()
对于张量 t
只是将其转换为 python 的默认 float32。
更重要的是,如果您是 PyTorch 的新手,了解我们使用 t.item()
来维持 运行 损失而不是 t
可能会对您有所帮助,因为 PyTorch 张量存储其值的历史可能会很快使您的 GPU 超载。