How do I compute bootstrapped cross entropy loss in PyTorch?
I have read some papers that use something called "bootstrapped cross entropy loss" to train their segmentation network. The idea is to focus only on the hardest k% (say, 15%) of the pixels to improve learning performance, especially when easy pixels dominate.
Currently, I am using the standard cross entropy:
loss = F.binary_cross_entropy(mask, gt)
How can I efficiently convert this into a bootstrapped version in PyTorch?
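My understanding is that it boils down to keeping only the hardest fraction of the per-pixel losses, roughly like this (just a sketch; top_p is my own name, and mask / gt are the tensors from the snippet above):

import torch
import torch.nn.functional as F

# mask: predicted probabilities, gt: targets in [0, 1], same shape, e.g. (N, 1, H, W)
top_p = 0.15  # keep only the hardest 15% of pixels
raw_loss = F.binary_cross_entropy(mask, gt, reduction='none').view(-1)
k = int(raw_loss.numel() * top_p)
loss = torch.topk(raw_loss, k, sorted=False)[0].mean()

Is this the right way to do it, and how do people usually schedule k?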
Typically we also add a "warm-up" period to the loss so that the network can first learn to adapt to the easy regions and then transition to the harder ones. This implementation starts from k=100 and keeps it for 20000 iterations, then linearly decays it to k=15 over another 50000 iterations.
import torch
import torch.nn as nn
import torch.nn.functional as F

class BootstrappedCE(nn.Module):
    def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
        super().__init__()
        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, input, target, it):
        # During warm-up, use plain cross entropy over all pixels (p = 1.0)
        if it < self.start_warm:
            return F.cross_entropy(input, target), 1.0

        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        if it > self.end_warm:
            this_p = self.top_p
        else:
            # Linearly decay p from 1.0 down to top_p between start_warm and end_warm
            this_p = self.top_p + (1 - self.top_p) * ((self.end_warm - it) / (self.end_warm - self.start_warm))

        # Keep only the hardest this_p fraction of per-pixel losses
        loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
        return loss.mean(), this_p
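For completeness, calling it would look roughly like this (dummy shapes; the 5-class setup and names logits / labels are just illustrative, not from the original answer):

criterion = BootstrappedCE(start_warm=20000, end_warm=70000, top_p=0.15)

logits = torch.randn(2, 5, 64, 64, requires_grad=True)   # raw scores (N, C, H, W)
labels = torch.randint(0, 5, (2, 64, 64))                 # class indices (N, H, W)

it = 30000                                 # current training iteration, tracked by the training loop
loss, used_p = criterion(logits, labels, it)
loss.backward()
print(loss.item(), used_p)                 # used_p lies between top_p and 1.0 during the decay phase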
Addition to the self-answer above (for my future self, and for API parity with PyTorch). One can first implement a functional version like this (with a few extra arguments on top of the original torch.nn.functional.cross_entropy; also, I prefer reduction to be a callable rather than one of the predefined strings):
import typing

import torch


def bootstrapped_cross_entropy(
    inputs,
    targets,
    iteration,
    p: float,
    warmup: typing.Union[typing.Callable[[float, int], float], int] = -1,
    weight=None,
    ignore_index=-100,
    reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
):
    if not 0 < p < 1:
        raise ValueError("p should be in (0, 1) range, got: {}".format(p))

    # Resolve the fraction of pixels to keep at this iteration
    if isinstance(warmup, int):
        this_p = 1.0 if iteration < warmup else p
    elif callable(warmup):
        this_p = warmup(p, iteration)
    else:
        raise ValueError(
            "warmup should be int or callable, got {}".format(type(warmup))
        )

    # Shortcut: during warm-up all pixels are kept; the callable reduction is applied
    # manually, since F.cross_entropy itself only accepts the predefined string reductions
    if this_p == 1.0:
        raw_loss = torch.nn.functional.cross_entropy(
            inputs, targets, weight=weight, ignore_index=ignore_index, reduction="none"
        )
        return reduction(raw_loss)

    raw_loss = torch.nn.functional.cross_entropy(
        inputs, targets, weight=weight, ignore_index=ignore_index, reduction="none"
    ).view(-1)
    num_pixels = raw_loss.numel()

    # Keep only the hardest this_p fraction of per-pixel losses
    loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
    return reduction(loss)
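As an example of a warmup callable, the linear decay schedule from the answer above could be sketched like this (start_warm / end_warm and the dummy tensors are illustrative only):

# Linearly decay the kept fraction from 1.0 down to p between start_warm and end_warm,
# mirroring the schedule used by BootstrappedCE above
start_warm, end_warm = 20000, 70000

def linear_warmup(p, iteration):
    if iteration < start_warm:
        return 1.0
    if iteration > end_warm:
        return p
    return p + (1 - p) * (end_warm - iteration) / (end_warm - start_warm)

inputs = torch.randn(2, 5, 64, 64, requires_grad=True)   # (N, C, H, W) logits
targets = torch.randint(0, 5, (2, 64, 64))                # (N, H, W) class indices

loss = bootstrapped_cross_entropy(inputs, targets, iteration=30000, p=0.15, warmup=linear_warmup)
loss.backward()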
warmup can be specified either as a callable (taking p and the current iteration) or as an int, which allows for flexible or simple scheduling. On top of that, here is a class based on _WeightedLoss whose iteration counter is incremented automatically on every call (so only inputs and targets have to be passed):
class BootstrappedCrossEntropy(torch.nn.modules.loss._WeightedLoss):
    def __init__(
        self,
        p: float,
        warmup: typing.Union[typing.Callable[[float, int], float], int] = -1,
        weight=None,
        ignore_index=-100,
        reduction: typing.Callable[[torch.Tensor], torch.Tensor] = torch.mean,
    ):
        self.p = p
        self.warmup = warmup
        self.ignore_index = ignore_index
        self._current_iteration = -1

        super().__init__(weight, size_average=None, reduce=None, reduction=reduction)

    def forward(self, inputs, targets):
        # Iteration counter advances automatically on every call
        self._current_iteration += 1
        return bootstrapped_cross_entropy(
            inputs,
            targets,
            self._current_iteration,
            self.p,
            self.warmup,
            self.weight,
            self.ignore_index,
            self.reduction,
        )
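Usage then only requires inputs and targets; a quick sketch (dummy shapes and an integer warmup chosen purely for illustration):

criterion = BootstrappedCrossEntropy(p=0.15, warmup=1000)

inputs = torch.randn(2, 5, 64, 64, requires_grad=True)   # (N, C, H, W) logits
targets = torch.randint(0, 5, (2, 64, 64))                # (N, H, W) class indices

loss = criterion(inputs, targets)   # iteration counter advances automatically
loss.backward()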