fix no gt in gpu focal loss (#3688)

* fix no gt in gpu focal loss

* fix num_pos==0

* fix gmm_separation_scheme
pull/3795/head
shilong 5 years ago committed by GitHub
parent 927d71a98c
commit 9f22edcea4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      mmdet/models/dense_heads/paa_head.py

@ -120,7 +120,6 @@ class PAAHead(ATSSHead):
bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
iou_preds = levels_to_images(iou_preds)
iou_preds = [item.reshape(-1, 1) for item in iou_preds]
pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
cls_scores, bbox_preds, labels,
labels_weight, bboxes_target,
@ -138,6 +137,8 @@ class PAAHead(ATSSHead):
anchor_list,
)
num_pos = sum(num_pos)
if num_pos == 0:
num_pos = len(img_metas)
# convert all tensor list to a flatten tensor
cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
@ -203,6 +204,8 @@ class PAAHead(ATSSHead):
Returns:
Tensor: Losses of all positive samples in single image.
"""
if not len(pos_inds):
return cls_score.new([]),
anchors_all_level = torch.cat(anchors, 0)
pos_scores = cls_score[pos_inds]
pos_bbox_pred = bbox_pred[pos_inds]
@ -359,6 +362,8 @@ class PAAHead(ATSSHead):
# https://github.com/kkhoot/PAA/issues/8 and
# https://github.com/kkhoot/PAA/issues/9.
fgs = gmm_assignment == 0
pos_inds_temp = fgs.new_tensor([], dtype=torch.long)
ignore_inds_temp = fgs.new_tensor([], dtype=torch.long)
if fgs.nonzero().numel():
_, pos_thr_ind = scores[fgs].topk(1)
pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]

Loading…
Cancel
Save