diff --git a/mmdet/models/dense_heads/paa_head.py b/mmdet/models/dense_heads/paa_head.py index 85a80e844..90fad1147 100644 --- a/mmdet/models/dense_heads/paa_head.py +++ b/mmdet/models/dense_heads/paa_head.py @@ -120,7 +120,6 @@ class PAAHead(ATSSHead): bbox_preds = [item.reshape(-1, 4) for item in bbox_preds] iou_preds = levels_to_images(iou_preds) iou_preds = [item.reshape(-1, 1) for item in iou_preds] - pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list, cls_scores, bbox_preds, labels, labels_weight, bboxes_target, @@ -138,6 +137,8 @@ class PAAHead(ATSSHead): anchor_list, ) num_pos = sum(num_pos) + if num_pos == 0: + num_pos = len(img_metas) # convert all tensor list to a flatten tensor cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1)) bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1)) @@ -203,6 +204,8 @@ class PAAHead(ATSSHead): Returns: Tensor: Losses of all positive samples in single image. """ + if not len(pos_inds): + return cls_score.new([]), anchors_all_level = torch.cat(anchors, 0) pos_scores = cls_score[pos_inds] pos_bbox_pred = bbox_pred[pos_inds] @@ -359,6 +362,8 @@ class PAAHead(ATSSHead): # https://github.com/kkhoot/PAA/issues/8 and # https://github.com/kkhoot/PAA/issues/9. fgs = gmm_assignment == 0 + pos_inds_temp = fgs.new_tensor([], dtype=torch.long) + ignore_inds_temp = fgs.new_tensor([], dtype=torch.long) if fgs.nonzero().numel(): _, pos_thr_ind = scores[fgs].topk(1) pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]