我想确认这些计算 Dice Loss 的方法中哪一个是正确的
I want to confirm which of these methods to calculate Dice Loss is correct
所以我有4种计算骰子损失的方法,其中3种返回相同的结果,所以我可以断定其中1种计算错误,但我想和你们确认一下:
import torch
torch.manual_seed(0)
inputs = torch.rand((3,1,224,224))
target = torch.rand((3,1,224,224))
方法一:展平张量
def method1(inputs, target):
inputs = inputs.reshape( -1)
target = target.reshape( -1)
intersection = (inputs * target).sum()
union = inputs.sum() + target.sum()
dice = (2. * intersection) / (union + 1e-8)
dice = dice.sum()
print("method1", dice)
方法二:除batch size之外的张量展平,对所有dims求和
def method2(inputs, target):
num = target.shape[0]
inputs = inputs.reshape(num, -1)
target = target.reshape(num, -1)
intersection = (inputs * target).sum()
union = inputs.sum() + target.sum()
dice = (2. * intersection) / (union + 1e-8)
dice = dice.sum()/num
print("method2", dice)
方法三:除batch size之外的张量展平,sum dim 1
def method3(inputs, target):
num = target.shape[0]
inputs = inputs.reshape(num, -1)
target = target.reshape(num, -1)
intersection = (inputs * target).sum(1)
union = inputs.sum(1) + target.sum(1)
dice = (2. * intersection) / (union + 1e-8)
dice = dice.sum()/num
print("method3", dice)
方法 4:不要展平张量
def method4(inputs, target):
intersection = (inputs * target).sum()
union = inputs.sum() + target.sum()
dice = (2. * intersection) / (union + 1e-8)
print("method4", dice)
method1(inputs, target)
method2(inputs, target)
method3(inputs, target)
method4(inputs, target)
方法1,3和4打印:0.5006
方法二打印:0.1669
这是有道理的,因为我在 3 个维度上展平了输入和目标,而忽略了批量大小,然后我对展平产生的所有 2 个维度求和,而不仅仅是暗淡的 1
方法4似乎是最优化的
首先,您需要确定报告的骰子分数:批次中所有样本的骰子分数(方法 1,2 和 4)或批次中每个样本的平均骰子分数(方法 3)。
如果我没记错的话,您想使用方法 3 - 您想要优化批次中每个样本的骰子得分,而不是“全局”骰子得分:假设您在“简单”中有一个“困难”样本“ 批。 “困难”样本的错误分类像素将忽略 w.r.t 所有其他像素。但是如果你分别看每个样本的骰子分数,那么“难”样本的骰子分数就不能忽略
所以我有4种计算骰子损失的方法,其中3种返回相同的结果,所以我可以断定其中1种计算错误,但我想和你们确认一下:
import torch
torch.manual_seed(0)
inputs = torch.rand((3,1,224,224))
target = torch.rand((3,1,224,224))
方法一:展平张量
def method1(inputs, target):
inputs = inputs.reshape( -1)
target = target.reshape( -1)
intersection = (inputs * target).sum()
union = inputs.sum() + target.sum()
dice = (2. * intersection) / (union + 1e-8)
dice = dice.sum()
print("method1", dice)
方法二:除batch size之外的张量展平,对所有dims求和
def method2(inputs, target):
num = target.shape[0]
inputs = inputs.reshape(num, -1)
target = target.reshape(num, -1)
intersection = (inputs * target).sum()
union = inputs.sum() + target.sum()
dice = (2. * intersection) / (union + 1e-8)
dice = dice.sum()/num
print("method2", dice)
方法三:除batch size之外的张量展平,sum dim 1
def method3(inputs, target):
num = target.shape[0]
inputs = inputs.reshape(num, -1)
target = target.reshape(num, -1)
intersection = (inputs * target).sum(1)
union = inputs.sum(1) + target.sum(1)
dice = (2. * intersection) / (union + 1e-8)
dice = dice.sum()/num
print("method3", dice)
方法 4:不要展平张量
def method4(inputs, target):
intersection = (inputs * target).sum()
union = inputs.sum() + target.sum()
dice = (2. * intersection) / (union + 1e-8)
print("method4", dice)
method1(inputs, target)
method2(inputs, target)
method3(inputs, target)
method4(inputs, target)
方法1,3和4打印:0.5006 方法二打印:0.1669
这是有道理的,因为我在 3 个维度上展平了输入和目标,而忽略了批量大小,然后我对展平产生的所有 2 个维度求和,而不仅仅是暗淡的 1
方法4似乎是最优化的
首先,您需要确定报告的骰子分数:批次中所有样本的骰子分数(方法 1,2 和 4)或批次中每个样本的平均骰子分数(方法 3)。
如果我没记错的话,您想使用方法 3 - 您想要优化批次中每个样本的骰子得分,而不是“全局”骰子得分:假设您在“简单”中有一个“困难”样本“ 批。 “困难”样本的错误分类像素将忽略 w.r.t 所有其他像素。但是如果你分别看每个样本的骰子分数,那么“难”样本的骰子分数就不能忽略