diff --git a/paddlers/rs_models/cd/losses/fccdn_loss.py b/paddlers/rs_models/cd/losses/fccdn_loss.py index 49d2b4c..259367f 100644 --- a/paddlers/rs_models/cd/losses/fccdn_loss.py +++ b/paddlers/rs_models/cd/losses/fccdn_loss.py @@ -43,42 +43,13 @@ class DiceLoss(nn.Layer): return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true) -class MultiClassDiceLoss(nn.Layer): - def __init__( - self, - weight, - batch=True, - ignore_index=-1, - do_softmax=False, - **kwargs, ): - super(MultiClassDiceLoss, self).__init__() - self.ignore_index = ignore_index - self.weight = weight - self.do_softmax = do_softmax - self.binary_diceloss = DiceLoss(batch) - - def forward(self, y_pred, y_true): - if self.do_softmax: - y_pred = paddle.nn.functional.softmax(y_pred, axis=1) - y_true = F.one_hot(y_true.long(), y_pred.shape[1]).permute(0, 3, 1, 2) - total_loss = 0.0 - tmp_i = 0.0 - for i in range(y_pred.shape[1]): - if i != self.ignore_index: - diceloss = self.binary_diceloss(y_pred[:, i, :, :], - y_true[:, i, :, :]) - total_loss += paddle.multiply(diceloss, self.weight[i]) - tmp_i += 1.0 - return total_loss / tmp_i - - class DiceBCELoss(nn.Layer): """Binary change detection task loss""" def __init__(self): super(DiceBCELoss, self).__init__() self.bce_loss = nn.BCELoss() - self.binnary_dice = DiceLoss() + self.binary_dice = DiceLoss() def forward(self, scores, labels, do_sigmoid=True): if len(scores.shape) > 3: @@ -87,29 +58,11 @@ class DiceBCELoss(nn.Layer): labels = labels.squeeze(1) if do_sigmoid: scores = paddle.nn.functional.sigmoid(scores.clone()) - diceloss = self.binnary_dice(scores, labels) + diceloss = self.binary_dice(scores, labels) bceloss = self.bce_loss(scores, labels) return diceloss + bceloss -class McDiceBCELoss(nn.Layer): - """Multi-class change detection task loss""" - - def __init__(self, weight, do_sigmoid=True): - super(McDiceBCELoss, self).__init__() - self.ce_loss = nn.CrossEntropyLoss(weight) - self.dice = MultiClassDiceLoss(weight, do_sigmoid) - - def forward(self, scores, labels): - if len(scores.shape) < 4: - scores = scores.unsqueeze(1) - if len(labels.shape) < 4: - labels = labels.unsqueeze(1) - diceloss = self.dice(scores, labels) - bceloss = self.ce_loss(scores, labels) - return diceloss + bceloss - - def fccdn_ssl_loss(logits_list, labels): """ Self-supervised learning loss for change detection. @@ -160,11 +113,11 @@ def fccdn_ssl_loss(logits_list, labels): # Seg loss labels_downsample = labels_downsample.astype(paddle.float32) - loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False) - loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False) - loss_aux += 0.2 * criterion_ssl( - out3, labels_downsample - pred_seg_post_tmp2, False) - loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, - False) + loss_aux = criterion_ssl(out1, pred_seg_post_tmp1, False) + loss_aux += criterion_ssl(out2, pred_seg_pre_tmp1, False) + loss_aux += criterion_ssl(out3, labels_downsample - pred_seg_post_tmp2, + False) + loss_aux += criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, + False) return loss_aux diff --git a/paddlers/tasks/change_detector.py b/paddlers/tasks/change_detector.py index 6df35d8..6f6cb69 100644 --- a/paddlers/tasks/change_detector.py +++ b/paddlers/tasks/change_detector.py @@ -1067,7 +1067,7 @@ class FCCDN(BaseChangeDetector): return { 'types': [seg_losses.CrossEntropyLoss(), cmcd.losses.fccdn_ssl_loss], - 'coef': [1.0, 1.0] + 'coef': [1.0, 0.2] } else: raise ValueError( diff --git a/tutorials/train/change_detection/fccdn.py b/tutorials/train/change_detection/fccdn.py index 318fa0e..7ac6fa7 100644 --- a/tutorials/train/change_detection/fccdn.py +++ b/tutorials/train/change_detection/fccdn.py @@ -78,7 +78,7 @@ model = pdrs.tasks.cd.FCCDN() # 执行模型训练 model.train( - num_epochs=10, + num_epochs=15, train_dataset=train_dataset, train_batch_size=4, eval_dataset=eval_dataset,