|
|
|
@ -110,11 +110,11 @@ class SegLoss(Loss): |
|
|
|
|
target_scores, target_scores_sum, fg_mask) |
|
|
|
|
for i in range(batch_size): |
|
|
|
|
if fg_mask[i].sum(): |
|
|
|
|
mask_idx = target_gt_idx[i][fg_mask[i]] + 1 |
|
|
|
|
mask_idx = target_gt_idx[i][fg_mask[i]] |
|
|
|
|
if self.overlap: |
|
|
|
|
gt_mask = torch.where(masks[[i]] == mask_idx.view(-1, 1, 1), 1.0, 0.0) |
|
|
|
|
gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0) |
|
|
|
|
else: |
|
|
|
|
gt_mask = masks[batch_idx == i][mask_idx] |
|
|
|
|
gt_mask = masks[batch_idx.view(-1) == i][mask_idx] |
|
|
|
|
xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]] |
|
|
|
|
marea = xyxy2xywh(xyxyn)[:, 2:].prod(1) |
|
|
|
|
mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device) |
|
|
|
|