|
|
|
@@ -10,7 +10,7 @@ from .metrics import bbox_iou
 TORCH_1_10 = check_version(torch.__version__, '1.10.0')
 
 
-def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9, roll_out=False):
     """select the positive anchor center in gt
 
     Args:
@@ -21,10 +21,18 @@ def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
     """
     n_anchors = xy_centers.shape[0]
     bs, n_boxes, _ = gt_bboxes.shape
-    lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
-    bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
-    # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
-    return bbox_deltas.amin(3).gt_(eps)
+    if roll_out:
+        bbox_deltas = torch.empty((bs, n_boxes, n_anchors), device=gt_bboxes.device)
+        for b in range(bs):
+            lt, rb = gt_bboxes[b].view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
+            bbox_deltas[b] = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]),
+                                       dim=2).view(n_boxes, n_anchors, -1).amin(2).gt_(eps)
+        return bbox_deltas
+    else:
+        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
+        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
+        # return (bbox_deltas.min(3)[0] > eps).to(gt_bboxes.dtype)
+        return bbox_deltas.amin(3).gt_(eps)
 
 
 def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
@@ -55,7 +63,7 @@ def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
 
 class TaskAlignedAssigner(nn.Module):
 
-    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9):
+    def __init__(self, topk=13, num_classes=80, alpha=1.0, beta=6.0, eps=1e-9, roll_out_thr=0):
         super().__init__()
         self.topk = topk
         self.num_classes = num_classes
@@ -63,6 +71,7 @@ class TaskAlignedAssigner(nn.Module):
         self.alpha = alpha
         self.beta = beta
         self.eps = eps
+        self.roll_out_thr = roll_out_thr
 
     @torch.no_grad()
     def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
@@ -84,6 +93,7 @@ class TaskAlignedAssigner(nn.Module):
         """
         self.bs = pd_scores.size(0)
         self.n_max_boxes = gt_bboxes.size(1)
+        self.roll_out = self.n_max_boxes > self.roll_out_thr if self.roll_out_thr else False
 
         if self.n_max_boxes == 0:
             device = gt_bboxes.device
@@ -112,7 +122,7 @@ class TaskAlignedAssigner(nn.Module):
         # get anchor_align metric, (b, max_num_obj, h*w)
         align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
         # get in_gts mask, (b, max_num_obj, h*w)
-        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes, roll_out=self.roll_out)
         # get topk_metric mask, (b, max_num_obj, h*w)
         mask_topk = self.select_topk_candidates(align_metric * mask_in_gts,
                                                 topk_mask=mask_gt.repeat([1, 1, self.topk]).bool())
@@ -122,14 +132,27 @@ class TaskAlignedAssigner(nn.Module):
         return mask_pos, align_metric, overlaps
 
     def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
-        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
-        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
-        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
-        # get the scores of each grid for each gt cls
-        bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
-
-        overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False, CIoU=True).squeeze(3).clamp(0)
-        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+        if self.roll_out:
+            align_metric = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
+            overlaps = torch.empty((self.bs, self.n_max_boxes, pd_scores.shape[1]), device=pd_scores.device)
+            ind_0 = torch.empty(self.n_max_boxes, dtype=torch.long)
+            for b in range(self.bs):
+                ind_0[:], ind_2 = b, gt_labels[b].squeeze(-1).long()
+                # get the scores of each grid for each gt cls
+                bbox_scores = pd_scores[ind_0, :, ind_2]  # b, max_num_obj, h*w
+                overlaps[b] = bbox_iou(gt_bboxes[b].unsqueeze(1), pd_bboxes[b].unsqueeze(0), xywh=False,
+                                       CIoU=True).squeeze(2).clamp(0)
+                align_metric[b] = bbox_scores.pow(self.alpha) * overlaps[b].pow(self.beta)
+        else:
+            ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+            ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
+            ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
+            # get the scores of each grid for each gt cls
+            bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
+
+            overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False,
+                                CIoU=True).squeeze(3).clamp(0)
+            align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
         return align_metric, overlaps
 
     def select_topk_candidates(self, metrics, largest=True, topk_mask=None):
@@ -145,9 +168,14 @@ class TaskAlignedAssigner(nn.Module):
         if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True) > self.eps).tile([1, 1, self.topk])
         # (b, max_num_obj, topk)
-        topk_idxs = torch.where(topk_mask, topk_idxs, 0)
+        topk_idxs[~topk_mask] = 0
         # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
-        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
+        if self.roll_out:
+            is_in_topk = torch.empty(metrics.shape, dtype=torch.long, device=metrics.device)
+            for b in range(len(topk_idxs)):
+                is_in_topk[b] = F.one_hot(topk_idxs[b], num_anchors).sum(-2)
+        else:
+            is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
         # filter invalid bboxes
         is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
         return is_in_topk.to(metrics.dtype)
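
Note: the new roll_out_thr argument is the only behavioural knob this diff adds. When the maximum per-image ground-truth count in a batch exceeds the threshold, the assigner sets self.roll_out and the functions above switch from the fully vectorised path to the per-image loops, keeping peak memory bounded on label-dense images at a small speed cost; roll_out_thr=0 (the default) leaves the original behaviour untouched. A minimal usage sketch follows; the module path, argument values, and the threshold of 64 are illustrative assumptions, not part of this diff.

    # Illustrative sketch only - values and import path are assumptions, not from this PR.
    from ultralytics.yolo.utils.tal import TaskAlignedAssigner

    # Default: roll_out_thr=0 disables the per-image loop entirely.
    assigner = TaskAlignedAssigner(topk=10, num_classes=80, alpha=0.5, beta=6.0)

    # With roll_out_thr=64, any batch whose padded GT count (n_max_boxes) exceeds 64
    # runs select_candidates_in_gts, get_box_metrics and select_topk_candidates in the
    # looped ("rolled out") code path, trading some speed for lower peak memory.
    assigner = TaskAlignedAssigner(topk=10, num_classes=80, alpha=0.5, beta=6.0, roll_out_thr=64)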
|
|
|
|