The Dice coefficient appears in the V-Net paper for medical image segmentation, targeting scenarios with imbalanced positive/negative samples; it measures the similarity between the prediction and the ground truth. It is closely related to IoU (for binary masks, $D = 2\cdot IoU/(1+IoU)$) and is equivalent to the F1 score.

Dice as Loss In V-Net

Dice

The following is from the V-Net paper.

$$
D = \frac{2\displaystyle\sum_{i}^{N} p_i g_i}{\displaystyle\sum_{i}^{N} p_{i}^{2}+\displaystyle\sum_{i}^{N} g_{i}^{2}}
$$

Here $p$ is the predicted probability, with values in $[0,1]$, and $g$ is the ground truth, taking values in $\{0,1\}$, so $p_i \times g_i$ acts as the intersection $X\cap Y$. The numerator is multiplied by 2 because the denominator counts the common elements twice.
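A small sketch of how the coefficient behaves (the tensors below are made-up toy values), including the relation to IoU for hard 0/1 predictions:

```python
import torch

# Toy example: soft predictions p in [0, 1], binary ground truth g.
p = torch.tensor([0.9, 0.8, 0.1, 0.2])
g = torch.tensor([1.0, 1.0, 0.0, 0.0])

# Dice with the squared denominator used in V-Net.
dice = 2 * (p * g).sum() / ((p ** 2).sum() + (g ** 2).sum())

# With hard 0/1 predictions, Dice and IoU are monotonically related:
# D = 2 * IoU / (1 + IoU).
p_hard = (p > 0.5).float()
inter = (p_hard * g).sum()
iou = inter / (p_hard.sum() + g.sum() - inter)
dice_hard = 2 * inter / (p_hard.sum() + g.sum())
print(dice.item(), iou.item(), dice_hard.item())
```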

This expression is also easy to differentiate:

$$
\frac{\partial D}{\partial p_j} = 2\left[\frac{g_j\left(\sum_i^N p_i^2 + \sum_i^N g_i^2\right) - 2p_j\left(\sum_i^N p_i g_i\right)}{\left(\sum_i^N p_i^2 + \sum_i^N g_i^2\right)^2}\right]
$$
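The gradient can be sanity-checked against autograd. This is a sketch, not code from the paper; `g` is a made-up fixed target:

```python
import torch

torch.manual_seed(0)
g = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0])
p = torch.rand(5, requires_grad=True)

# Dice with the squared-denominator form.
D = 2 * (p * g).sum() / ((p ** 2).sum() + (g ** 2).sum())
D.backward()

# Analytic gradient from the formula above.
with torch.no_grad():
    denom = (p ** 2).sum() + (g ** 2).sum()
    analytic = 2 * (g * denom - 2 * p * (p * g).sum()) / denom ** 2

print(torch.allclose(p.grad, analytic, atol=1e-6))
```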

Dice Loss

The loss is simply $1-D$. A smoothing term can also be added: add $1$ (sometimes written $\gamma$) to both the numerator and the denominator of $D$.
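A minimal sketch of the smoothed loss (the function name is illustrative), showing why the smooth term matters when both the prediction and the ground truth are empty:

```python
import torch

def dice_loss(p, g, smooth=1.0):
    # 1 - Dice, with the smooth term added to numerator and denominator.
    inter = (p * g).sum()
    return 1 - (2 * inter + smooth) / (p.sum() + g.sum() + smooth)

# When both prediction and ground truth are all zeros, the smooth term
# keeps the loss at 0 instead of producing an indeterminate 0/0.
p = torch.zeros(4)
g = torch.zeros(4)
print(dice_loss(p, g).item())  # 0.0
```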

Generalised Dice Loss

In the Dice Loss paper, the denominator of the Dice function need not be a sum of squares:

$$
L_{DL} = 1 - \frac{2|X\cap Y| + 1}{|X| + |Y| + 1}
$$

Concretely, the numerator takes the element-wise product of X and Y to get $|X\cap Y|$, and `X.sum() + Y.sum()` gives $|X|+|Y|$.

```python
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, pred: torch.Tensor, gt, mask=None):
        """
        :param pred: [N, NUM_CLASSES, H, W]
        :param gt: [N, NUM_CLASSES, H, W]
        :param mask: [N, H, W]
        :return: loss
        """
        assert pred.dim() == 4, pred.dim()
        assert pred.shape == gt.shape

        if mask is not None:
            assert (
                mask.shape == pred.shape[:1] + pred.shape[2:]
            ), f"pred.shape {pred.shape} mask.shape {mask.shape}"
            mask = mask.unsqueeze(1)  # [N, 1, H, W], broadcasts over classes
            intersection = (pred * gt * mask).sum()
            union = (pred * mask).sum() + (gt * mask).sum()
        else:
            intersection = (pred * gt).sum()
            union = pred.sum() + gt.sum()

        loss = 1 - (2.0 * intersection + self.eps) / (union + self.eps)
        return loss
```

For multi-class Dice Loss, combine the per-class terms with weights and sum:

$$
DL = 1 - \frac{2\sum_{k=1}^K w_k \sum_n^N p_{kn} g_{kn} + \epsilon}{\sum_{k=1}^K w_k \sum_n^N \left(p_{kn}+g_{kn}\right)+\epsilon}
$$

Here $k$ indexes the $K$ classes and $n$ indexes the elements.
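The multi-class formula can be sketched as follows (the function name and shapes are illustrative, not from the paper):

```python
import torch

def multiclass_dice_loss(pred, gt, weights, eps=1e-6):
    """Weighted multi-class Dice loss sketch.

    pred, gt: [K, N] -- per-class probabilities and one-hot targets
    weights:  [K]    -- per-class weights w_k
    """
    inter = (pred * gt).sum(dim=1)       # per-class sum of p_kn * g_kn
    total = (pred + gt).sum(dim=1)       # per-class sum of p_kn + g_kn
    num = 2 * (weights * inter).sum() + eps
    den = (weights * total).sum() + eps
    return 1 - num / den

# A perfect prediction drives the loss to ~0.
gt = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
w = torch.tensor([0.3, 0.7])
loss = multiclass_dice_loss(gt.clone(), gt, w)
print(loss.item())
```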

Comparing Losses under Class Imbalance

According to the referenced article, these losses all perform about the same as cross entropy, with only a slight edge in class-imbalanced settings.

Although focal_loss and dice_loss did not clearly outperform cross_entropy, that is, they did not decisively solve the imbalance problem, the results at least show that they are as effective as cross_entropy.

Related papers: