Dice 出现在论文 V-Net 中,用于医疗图像检测,针对正负样本不均衡的场景,用于评估预测和GT相似度。就是IOU
Dice as Loss In V-Net
Dice
以下内容来自VNet
D=i∑Npi2+i∑Ngi22i∑Npigi
p是概率,取值[0,1],g是ground truth,取值0,1,所以pi×gi就算交集X∩Y了,分子乘2因为分母的总数算了两遍。
这个式子也很好求导
∂pj∂D=2[(∑iNpi2+gi2)2gi(∑iNpi2+gi2)−2pj(∑iNpigi)]
Dice Loss
1−D即可,可以加个Smooth,即 D中分子分母各加1(有的地方写γ)
Generalised Dice Loss
Dice Loss 论文中。Dice function的分母可以不是平方和
LDL=1−∣X∣+∣Y∣+12∣X∩Y∣+1
具体计算就是分子 X 与 Y 按元素乘得到∣X∩Y∣,X.sum() + Y.sum()
得到∣X∣+∣Y∣
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
| class DiceLoss(nn.Module):
def __init__(self, eps=1e-6): super(DiceLoss, self).__init__() self.eps = eps
def forward(self, pred: torch.Tensor, gt, mask=None): """ :param pred: [N, NUM_CLASSES, H, W] :param gt: [N, NUM_CLASSES, H, W] :param mask: [N, H, W] :return: loss """ assert pred.dim() == 4, pred.dim() assert pred.shape == gt.shape
if mask is not None: assert ( pred.shape == mask.shape ), f"pred.shape {pred.shape} mask.shape {mask.shape}"
intersection = (pred * gt * mask).sum() union = (pred * mask).sum() + (gt * mask).sum() else: intersection = (pred * gt).sum() union = pred.sum() + gt.sum()
loss = 1 - (2.0 * intersection + self.eps) / (union + self.eps) return loss
|
对于多分类的Dice Loss,直接把各个分类的Dice Loss加权求和
DL=1−∑k=1Kwk∑nNpkn+gkn+ϵ2∑k=1Kwk∑nNpkngkn+ϵ
k为k classes,n为每个元素
样本不均衡的Loss 对比
根据参考文章,都和Cross Entropy差不多,只是针对样本不均衡的情况好一点。
虽然focal_loss,dice_loss没有比cross_entropy表现多出色,也就是说并没有有效的解决不平衡性问题,但至少证明了它们跟cross_entropy一样是有效的
相关论文: