|
|
|
@ -120,7 +120,6 @@ class PAAHead(ATSSHead): |
|
|
|
|
bbox_preds = [item.reshape(-1, 4) for item in bbox_preds] |
|
|
|
|
iou_preds = levels_to_images(iou_preds) |
|
|
|
|
iou_preds = [item.reshape(-1, 1) for item in iou_preds] |
|
|
|
|
|
|
|
|
|
pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list, |
|
|
|
|
cls_scores, bbox_preds, labels, |
|
|
|
|
labels_weight, bboxes_target, |
|
|
|
@ -138,6 +137,8 @@ class PAAHead(ATSSHead): |
|
|
|
|
anchor_list, |
|
|
|
|
) |
|
|
|
|
num_pos = sum(num_pos) |
|
|
|
|
if num_pos == 0: |
|
|
|
|
num_pos = len(img_metas) |
|
|
|
|
# convert all tensor list to a flatten tensor |
|
|
|
|
cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1)) |
|
|
|
|
bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1)) |
|
|
|
@ -203,6 +204,8 @@ class PAAHead(ATSSHead): |
|
|
|
|
Returns: |
|
|
|
|
Tensor: Losses of all positive samples in single image. |
|
|
|
|
""" |
|
|
|
|
if not len(pos_inds): |
|
|
|
|
return cls_score.new([]), |
|
|
|
|
anchors_all_level = torch.cat(anchors, 0) |
|
|
|
|
pos_scores = cls_score[pos_inds] |
|
|
|
|
pos_bbox_pred = bbox_pred[pos_inds] |
|
|
|
@ -359,6 +362,8 @@ class PAAHead(ATSSHead): |
|
|
|
|
# https://github.com/kkhoot/PAA/issues/8 and |
|
|
|
|
# https://github.com/kkhoot/PAA/issues/9. |
|
|
|
|
fgs = gmm_assignment == 0 |
|
|
|
|
pos_inds_temp = fgs.new_tensor([], dtype=torch.long) |
|
|
|
|
ignore_inds_temp = fgs.new_tensor([], dtype=torch.long) |
|
|
|
|
if fgs.nonzero().numel(): |
|
|
|
|
_, pos_thr_ind = scores[fgs].topk(1) |
|
|
|
|
pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1] |
|
|
|
|