Remove unused classes in fccdn

own
Bobholamovic 2 years ago
parent 9c1b2ea2fe
commit 92a5086c79
  1. 63
      paddlers/rs_models/cd/losses/fccdn_loss.py
  2. 2
      paddlers/tasks/change_detector.py
  3. 2
      tutorials/train/change_detection/fccdn.py

@ -43,42 +43,13 @@ class DiceLoss(nn.Layer):
return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true) return self.soft_dice_loss(y_pred.astype(paddle.float32), y_true)
class MultiClassDiceLoss(nn.Layer):
def __init__(
self,
weight,
batch=True,
ignore_index=-1,
do_softmax=False,
**kwargs, ):
super(MultiClassDiceLoss, self).__init__()
self.ignore_index = ignore_index
self.weight = weight
self.do_softmax = do_softmax
self.binary_diceloss = DiceLoss(batch)
def forward(self, y_pred, y_true):
if self.do_softmax:
y_pred = paddle.nn.functional.softmax(y_pred, axis=1)
y_true = F.one_hot(y_true.long(), y_pred.shape[1]).permute(0, 3, 1, 2)
total_loss = 0.0
tmp_i = 0.0
for i in range(y_pred.shape[1]):
if i != self.ignore_index:
diceloss = self.binary_diceloss(y_pred[:, i, :, :],
y_true[:, i, :, :])
total_loss += paddle.multiply(diceloss, self.weight[i])
tmp_i += 1.0
return total_loss / tmp_i
class DiceBCELoss(nn.Layer): class DiceBCELoss(nn.Layer):
"""Binary change detection task loss""" """Binary change detection task loss"""
def __init__(self): def __init__(self):
super(DiceBCELoss, self).__init__() super(DiceBCELoss, self).__init__()
self.bce_loss = nn.BCELoss() self.bce_loss = nn.BCELoss()
self.binnary_dice = DiceLoss() self.binary_dice = DiceLoss()
def forward(self, scores, labels, do_sigmoid=True): def forward(self, scores, labels, do_sigmoid=True):
if len(scores.shape) > 3: if len(scores.shape) > 3:
@ -87,29 +58,11 @@ class DiceBCELoss(nn.Layer):
labels = labels.squeeze(1) labels = labels.squeeze(1)
if do_sigmoid: if do_sigmoid:
scores = paddle.nn.functional.sigmoid(scores.clone()) scores = paddle.nn.functional.sigmoid(scores.clone())
diceloss = self.binnary_dice(scores, labels) diceloss = self.binary_dice(scores, labels)
bceloss = self.bce_loss(scores, labels) bceloss = self.bce_loss(scores, labels)
return diceloss + bceloss return diceloss + bceloss
class McDiceBCELoss(nn.Layer):
"""Multi-class change detection task loss"""
def __init__(self, weight, do_sigmoid=True):
super(McDiceBCELoss, self).__init__()
self.ce_loss = nn.CrossEntropyLoss(weight)
self.dice = MultiClassDiceLoss(weight, do_sigmoid)
def forward(self, scores, labels):
if len(scores.shape) < 4:
scores = scores.unsqueeze(1)
if len(labels.shape) < 4:
labels = labels.unsqueeze(1)
diceloss = self.dice(scores, labels)
bceloss = self.ce_loss(scores, labels)
return diceloss + bceloss
def fccdn_ssl_loss(logits_list, labels): def fccdn_ssl_loss(logits_list, labels):
""" """
Self-supervised learning loss for change detection. Self-supervised learning loss for change detection.
@ -160,11 +113,11 @@ def fccdn_ssl_loss(logits_list, labels):
# Seg loss # Seg loss
labels_downsample = labels_downsample.astype(paddle.float32) labels_downsample = labels_downsample.astype(paddle.float32)
loss_aux = 0.2 * criterion_ssl(out1, pred_seg_post_tmp1, False) loss_aux = criterion_ssl(out1, pred_seg_post_tmp1, False)
loss_aux += 0.2 * criterion_ssl(out2, pred_seg_pre_tmp1, False) loss_aux += criterion_ssl(out2, pred_seg_pre_tmp1, False)
loss_aux += 0.2 * criterion_ssl( loss_aux += criterion_ssl(out3, labels_downsample - pred_seg_post_tmp2,
out3, labels_downsample - pred_seg_post_tmp2, False) False)
loss_aux += 0.2 * criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2, loss_aux += criterion_ssl(out4, labels_downsample - pred_seg_pre_tmp2,
False) False)
return loss_aux return loss_aux

@ -1067,7 +1067,7 @@ class FCCDN(BaseChangeDetector):
return { return {
'types': 'types':
[seg_losses.CrossEntropyLoss(), cmcd.losses.fccdn_ssl_loss], [seg_losses.CrossEntropyLoss(), cmcd.losses.fccdn_ssl_loss],
'coef': [1.0, 1.0] 'coef': [1.0, 0.2]
} }
else: else:
raise ValueError( raise ValueError(

@ -78,7 +78,7 @@ model = pdrs.tasks.cd.FCCDN()
# 执行模型训练 # 执行模型训练
model.train( model.train(
num_epochs=10, num_epochs=15,
train_dataset=train_dataset, train_dataset=train_dataset,
train_batch_size=4, train_batch_size=4,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,

Loading…
Cancel
Save