|
|
|
@ -58,17 +58,45 @@ class TaskAlignedAssigner(nn.Module): |
|
|
|
|
""" |
|
|
|
|
self.bs = pd_scores.shape[0] |
|
|
|
|
self.n_max_boxes = gt_bboxes.shape[1] |
|
|
|
|
device = gt_bboxes.device |
|
|
|
|
|
|
|
|
|
if self.n_max_boxes == 0: |
|
|
|
|
device = gt_bboxes.device |
|
|
|
|
return ( |
|
|
|
|
torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), |
|
|
|
|
torch.zeros_like(pd_bboxes).to(device), |
|
|
|
|
torch.zeros_like(pd_scores).to(device), |
|
|
|
|
torch.zeros_like(pd_scores[..., 0]).to(device), |
|
|
|
|
torch.zeros_like(pd_scores[..., 0]).to(device), |
|
|
|
|
torch.full_like(pd_scores[..., 0], self.bg_idx), |
|
|
|
|
torch.zeros_like(pd_bboxes), |
|
|
|
|
torch.zeros_like(pd_scores), |
|
|
|
|
torch.zeros_like(pd_scores[..., 0]), |
|
|
|
|
torch.zeros_like(pd_scores[..., 0]), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt) |
|
|
|
|
except torch.OutOfMemoryError: |
|
|
|
|
# Move tensors to CPU, compute, then move back to original device |
|
|
|
|
cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)] |
|
|
|
|
result = self._forward(*cpu_tensors) |
|
|
|
|
return tuple(t.to(device) for t in result) |
|
|
|
|
|
|
|
|
|
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): |
|
|
|
|
""" |
|
|
|
|
Compute the task-aligned assignment. Reference code is available at |
|
|
|
|
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
pd_scores (Tensor): shape(bs, num_total_anchors, num_classes) |
|
|
|
|
pd_bboxes (Tensor): shape(bs, num_total_anchors, 4) |
|
|
|
|
anc_points (Tensor): shape(num_total_anchors, 2) |
|
|
|
|
gt_labels (Tensor): shape(bs, n_max_boxes, 1) |
|
|
|
|
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) |
|
|
|
|
mask_gt (Tensor): shape(bs, n_max_boxes, 1) |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
target_labels (Tensor): shape(bs, num_total_anchors) |
|
|
|
|
target_bboxes (Tensor): shape(bs, num_total_anchors, 4) |
|
|
|
|
target_scores (Tensor): shape(bs, num_total_anchors, num_classes) |
|
|
|
|
fg_mask (Tensor): shape(bs, num_total_anchors) |
|
|
|
|
target_gt_idx (Tensor): shape(bs, num_total_anchors) |
|
|
|
|
""" |
|
|
|
|
mask_pos, align_metric, overlaps = self.get_pos_mask( |
|
|
|
|
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt |
|
|
|
|
) |
|
|
|
|