|
|
|
@ -43,42 +43,13 @@ class DiceLoss(nn.Layer): |
|
|
|
|
return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MultiClassDiceLoss(nn.Layer): |
|
|
|
|
def __init__( |
|
|
|
|
self, |
|
|
|
|
weight, |
|
|
|
|
batch=True, |
|
|
|
|
ignore_index=-1, |
|
|
|
|
do_softmax=False, |
|
|
|
|
**kwargs, ): |
|
|
|
|
super(MultiClassDiceLoss, self).__init__() |
|
|
|
|
self.ignore_index = ignore_index |
|
|
|
|
self.weight = weight |
|
|
|
|
self.do_softmax = do_softmax |
|
|
|
|
self.binary_diceloss = DiceLoss(batch) |
|
|
|
|
|
|
|
|
|
def forward(self, y_pred, y_true): |
|
|
|
|
if self.do_softmax: |
|
|
|
|
y_pred = paddle.nn.functional.softmax(y_pred, axis=1) |
|
|
|
|
y_true = F.one_hot(y_true.long(), y_pred.shape[1]).permute(0, 3, 1, 2) |
|
|
|
|
total_loss = 0.0 |
|
|
|
|
tmp_i = 0.0 |
|
|
|
|
for i in range(y_pred.shape[1]): |
|
|
|
|
if i != self.ignore_index: |
|
|
|
|
diceloss = self.binary_diceloss(y_pred[:, i, :, :], |
|
|
|
|
y_true[:, i, :, :]) |
|
|
|
|
total_loss += paddle.multiply(diceloss, self.weight[i]) |
|
|
|
|
tmp_i += 1.0 |
|
|
|
|
return total_loss / tmp_i |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DiceBCELoss(nn.Layer): |
|
|
|
|
"""Binary change detection task loss""" |
|
|
|
|
|
|
|
|
|
def __init__(self): |
|
|
|
|
super(DiceBCELoss, self).__init__() |
|
|
|
|
self.bce_loss = nn.BCELoss() |
|
|
|
|
self.binnary_dice = DiceLoss() |
|
|
|
|
self.binary_dice = DiceLoss() |
|
|
|
|
|
|
|
|
|
def forward(self, scores, labels, do_sigmoid=True): |
|
|
|
|
if len(scores.shape) > 3: |
|
|
|
@ -87,29 +58,11 @@ class DiceBCELoss(nn.Layer): |
|
|
|
|
labels = labels.squeeze(1) |
|
|
|
|
if do_sigmoid: |
|
|
|
|
scores = paddle.nn.functional.sigmoid(scores.clone()) |
|
|
|
|
diceloss = self.binnary_dice(scores, labels) |
|
|
|
|
diceloss = self.binary_dice(scores, labels) |
|
|
|
|
bceloss = self.bce_loss(scores, labels) |
|
|
|
|
return diceloss + bceloss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class McDiceBCELoss(nn.Layer): |
|
|
|
|
"""Multi-class change detection task loss""" |
|
|
|
|
|
|
|
|
|
def __init__(self, weight, do_sigmoid=True): |
|
|
|
|
super(McDiceBCELoss, self).__init__() |
|
|
|
|
self.ce_loss = nn.CrossEntropyLoss(weight) |
|
|
|
|
self.dice = MultiClassDiceLoss(weight, do_sigmoid) |
|
|
|
|
|
|
|
|
|
def forward(self, scores, labels): |
|
|
|
|
if len(scores.shape) < 4: |
|
|
|
|
scores = scores.unsqueeze(1) |
|
|
|
|
if len(labels.shape) < 4: |
|
|
|
|
labels = labels.unsqueeze(1) |
|
|
|
|
diceloss = self.dice(scores, labels) |
|
|
|
|
bceloss = self.ce_loss(scores, labels) |
|
|
|
|
return diceloss + bceloss |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fccdn_ssl_loss(logits_list, labels): |
|
|
|
|
""" |
|
|
|
|
Self-supervised learning loss for change detection. |
|
|
|
@ -160,11 +113,11 @@ def fccdn_ssl_loss(logits_list, labels): |
|
|
|
|
|
|
|
|
|
# Seg loss |
|
|
|
|
labels_downsample = labels_downsample.astype(paddle.float32) |
|
|
|
|
loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False) |
|
|
|
|
loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False) |
|
|
|
|
loss_aux += 0.2 * criterion_ssl( |
|
|
|
|
out3, labels_downsample - pred_seg_post_tmp2, False) |
|
|
|
|
loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, |
|
|
|
|
False) |
|
|
|
|
loss_aux = criterion_ssl(out1, pred_seg_post_tmp1, False) |
|
|
|
|
loss_aux += criterion_ssl(out2, pred_seg_pre_tmp1, False) |
|
|
|
|
loss_aux += criterion_ssl(out3, labels_downsample - pred_seg_post_tmp2, |
|
|
|
|
False) |
|
|
|
|
loss_aux += criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, |
|
|
|
|
False) |
|
|
|
|
|
|
|
|
|
return loss_aux |
|
|
|
|