diff --git a/configs/paa/README.md b/configs/paa/README.md new file mode 100644 index 000000000..5443b39a8 --- /dev/null +++ b/configs/paa/README.md @@ -0,0 +1,22 @@ +# Probabilistic Anchor Assignment with IoU Prediction for Object Detection + + + +## Results and Models +We provide config files to reproduce the object detection results in the +ECCV 2020 paper for Probabilistic Anchor Assignment with IoU +Prediction for Object Detection. + +| Backbone | Lr schd | Mem (GB) | Score voting | box AP | Download | +|:-----------:|:-------:|:--------:|:------------:|:------:|:--------:| +| R-50-FPN | 12e | 3.7 | True | 40.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_20200821-936edec3.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1x_20200821-936edec3.log.json) | +| R-50-FPN | 12e | 3.7 | False | 40.2 | - | +| R-50-FPN | 18e | 3.7 | True | 41.4 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1.5x_20200823-805d6078.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_1.5x_20200823-805d6078.log.json) | +| R-50-FPN | 18e | 3.7 | False | 41.2 | - | +| R-50-FPN | 24e | 3.7 | True | 41.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_2x_20200821-c98bfc4e.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r50_fpn_2x_20200821-c98bfc4e.log.json) | +| R-101-FPN | 12e | 6.2 | True | 42.6 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_20200821-0a1825a4.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_1x_20200821-0a1825a4.log.json) | +| R-101-FPN | 12e | 6.2 | False | 42.4 | - | +| R-101-FPN | 24e | 6.2 | True | 43.5 | [model](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_2x_20200821-6829f96b.pth) | [log](https://open-mmlab.s3.ap-northeast-2.amazonaws.com/mmdetection/v2.0/paa/paa_r101_fpn_2x_20200821-6829f96b.log.json) | + +**Note**: +1. We find that the performance is unstable with 1x setting and may fluctuate by about 0.2 mAP. We report the best results. diff --git a/configs/paa/paa_r101_fpn_1x_coco.py b/configs/paa/paa_r101_fpn_1x_coco.py new file mode 100644 index 000000000..a64a012dd --- /dev/null +++ b/configs/paa/paa_r101_fpn_1x_coco.py @@ -0,0 +1,4 @@ +_base_ = './paa_r50_fpn_1x_coco.py' +model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101)) +lr_config = dict(step=[16, 22]) +total_epochs = 24 diff --git a/configs/paa/paa_r101_fpn_2x_coco.py b/configs/paa/paa_r101_fpn_2x_coco.py new file mode 100644 index 000000000..a3bc60f91 --- /dev/null +++ b/configs/paa/paa_r101_fpn_2x_coco.py @@ -0,0 +1,3 @@ +_base_ = './paa_r101_fpn_1x_coco.py' +lr_config = dict(step=[16, 22]) +total_epochs = 24 diff --git a/configs/paa/paa_r50_fpn_1.5x_coco.py b/configs/paa/paa_r50_fpn_1.5x_coco.py new file mode 100644 index 000000000..7de45783b --- /dev/null +++ b/configs/paa/paa_r50_fpn_1.5x_coco.py @@ -0,0 +1,3 @@ +_base_ = './paa_r50_fpn_1x_coco.py' +lr_config = dict(step=[12, 16]) +total_epochs = 18 diff --git a/configs/paa/paa_r50_fpn_1x_coco.py b/configs/paa/paa_r50_fpn_1x_coco.py new file mode 100644 index 000000000..e66cd1b77 --- /dev/null +++ b/configs/paa/paa_r50_fpn_1x_coco.py @@ -0,0 +1,70 @@ +_base_ = [ + '../_base_/datasets/coco_detection.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] +model = dict( + type='PAA', + pretrained='torchvision://resnet50', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_output', + num_outs=5), + bbox_head=dict( + type='PAAHead', + reg_decoded_bbox=True, + score_voting=True, + topk=9, + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + octave_base_scale=8, + scales_per_octave=1, + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[0.1, 0.1, 0.2, 0.2]), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=1.3), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5))) +# training and testing settings +train_cfg = dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.1, + neg_iou_thr=0.1, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False) +test_cfg = dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100) +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) diff --git a/configs/paa/paa_r50_fpn_2x_coco.py b/configs/paa/paa_r50_fpn_2x_coco.py new file mode 100644 index 000000000..529f07439 --- /dev/null +++ b/configs/paa/paa_r50_fpn_2x_coco.py @@ -0,0 +1,3 @@ +_base_ = './paa_r50_fpn_1x_coco.py' +lr_config = dict(step=[16, 22]) +total_epochs = 24 diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index 8719e109e..9f5ec181a 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -11,6 +11,7 @@ from .ga_rpn_head import GARPNHead from .gfl_head import GFLHead from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead from .nasfcos_head import NASFCOSHead +from .paa_head import PAAHead from .pisa_retinanet_head import PISARetinaHead from .pisa_ssd_head import PISASSDHead from .reppoints_head import RepPointsHead @@ -24,5 +25,5 @@ __all__ = [ 'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead', 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead', 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead', - 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead' + 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'PAAHead' ] diff --git a/mmdet/models/dense_heads/paa_head.py b/mmdet/models/dense_heads/paa_head.py new file mode 100644 index 000000000..85a80e844 --- /dev/null +++ b/mmdet/models/dense_heads/paa_head.py @@ -0,0 +1,620 @@ +import numpy as np +import torch + +from mmdet.core import force_fp32, multi_apply, multiclass_nms +from mmdet.core.bbox.iou_calculators import bbox_overlaps +from mmdet.models import HEADS +from mmdet.models.dense_heads import ATSSHead + +eps = 1e-12 +try: + import sklearn.mixture as skm +except ImportError: + skm = None + + +def levels_to_images(mlvl_tensor): + """Concat multi-level feature maps by image. + + [feature_level0, feature_level1...] -> [feature_image0, feature_image1...] + Convert the shape of each element in mlvl_tensor from (N, C, H, W) to + (N, H*W , C), then split the element to N elements with shape (H*W, C), and + concat elements in same image of all level along first dimension. + + Args: + mlvl_tensor (list[torch.Tensor]): list of Tensor which collect from + corresponding level. Each element is of shape (N, C, H, W) + + Returns: + list[torch.Tensor]: A list that contains N tensors and each tensor is + of shape (num_elements, C) + """ + batch_size = mlvl_tensor[0].size(0) + batch_list = [[] for _ in range(batch_size)] + channels = mlvl_tensor[0].size(1) + for t in mlvl_tensor: + t = t.permute(0, 2, 3, 1) + t = t.view(batch_size, -1, channels).contiguous() + for img in range(batch_size): + batch_list[img].append(t[img]) + return [torch.cat(item, 0) for item in batch_list] + + +@HEADS.register_module() +class PAAHead(ATSSHead): + """Head of PAAAssignment: Probabilistic Anchor Assignment with IoU + Prediction for Object Detection. + + Code is modified from the `official github repo + `_. + + More details can be found in the `paper + `_ . + + Args: + topk (int): Select topk samples with smallest loss in + each level. + score_voting (bool): Whether to use score voting in post-process. + """ + + def __init__(self, *args, topk=9, score_voting=True, **kwargs): + # topk used in paa reassign process + self.topk = topk + self.with_score_voting = score_voting + super(PAAHead, self).__init__(*args, **kwargs) + + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds')) + def loss(self, + cls_scores, + bbox_preds, + iou_preds, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + iou_preds (list[Tensor]): iou_preds for each scale + level with shape (N, num_anchors * 1, H, W) + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (list[Tensor] | None): Specify which bounding + boxes can be ignored when are computing the loss. + + Returns: + dict[str, Tensor]: A dictionary of loss gmm_assignment. + """ + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.anchor_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, img_metas, device=device) + label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1 + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + gt_bboxes, + img_metas, + gt_bboxes_ignore_list=gt_bboxes_ignore, + gt_labels_list=gt_labels, + label_channels=label_channels, + ) + (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds, + pos_gt_index) = cls_reg_targets + cls_scores = levels_to_images(cls_scores) + cls_scores = [ + item.reshape(-1, self.cls_out_channels) for item in cls_scores + ] + bbox_preds = levels_to_images(bbox_preds) + bbox_preds = [item.reshape(-1, 4) for item in bbox_preds] + iou_preds = levels_to_images(iou_preds) + iou_preds = [item.reshape(-1, 1) for item in iou_preds] + + pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list, + cls_scores, bbox_preds, labels, + labels_weight, bboxes_target, + bboxes_weight, pos_inds) + + with torch.no_grad(): + labels, label_weights, bbox_weights, num_pos = multi_apply( + self.paa_reassign, + pos_losses_list, + labels, + labels_weight, + bboxes_weight, + pos_inds, + pos_gt_index, + anchor_list, + ) + num_pos = sum(num_pos) + # convert all tensor list to a flatten tensor + cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1)) + bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1)) + iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1)) + labels = torch.cat(labels, 0).view(-1) + flatten_anchors = torch.cat( + [torch.cat(item, 0) for item in anchor_list]) + labels_weight = torch.cat(labels_weight, 0).view(-1) + bboxes_target = torch.cat(bboxes_target, + 0).view(-1, bboxes_target[0].size(-1)) + + pos_inds_flatten = ( + (labels >= 0) + & (labels < self.background_label)).nonzero().reshape(-1) + + losses_cls = self.loss_cls( + cls_scores, labels, labels_weight, avg_factor=num_pos) + if num_pos: + pos_bbox_pred = self.bbox_coder.decode( + flatten_anchors[pos_inds_flatten], + bbox_preds[pos_inds_flatten]) + pos_bbox_target = bboxes_target[pos_inds_flatten] + iou_target = bbox_overlaps( + pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True) + losses_iou = self.loss_centerness( + iou_preds[pos_inds_flatten], + iou_target.unsqueeze(-1), + avg_factor=num_pos) + losses_bbox = self.loss_bbox( + pos_bbox_pred, + pos_bbox_target, + iou_target.clamp(min=eps), + avg_factor=iou_target.sum()) + else: + losses_iou = iou_preds.sum() * 0 + losses_bbox = bbox_preds.sum() * 0 + + return dict( + loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou) + + def get_pos_loss(self, anchors, cls_score, bbox_pred, label, label_weight, + bbox_target, bbox_weight, pos_inds): + """Calculate loss of all potential positive samples obtained from first + match process. + + Args: + anchors (list[Tensor]): Anchors of each scale. + cls_score (Tensor): Box scores of single image with shape + (num_anchors, num_classes) + bbox_pred (Tensor): Box energies / deltas of single image + with shape (num_anchors, 4) + label (Tensor): classification target of each anchor with + shape (num_anchors,) + label_weight (Tensor): Classification loss weight of each + anchor with shape (num_anchors). + bbox_target (dict): Regression target of each anchor with + shape (num_anchors, 4). + bbox_weight (Tensor): Bbox weight of each anchor with shape + (num_anchors, 4). + pos_inds (Tensor): Index of all positive samples got from + first assign process. + + Returns: + Tensor: Losses of all positive samples in single image. + """ + anchors_all_level = torch.cat(anchors, 0) + pos_scores = cls_score[pos_inds] + pos_bbox_pred = bbox_pred[pos_inds] + pos_label = label[pos_inds] + pos_label_weight = label_weight[pos_inds] + pos_bbox_target = bbox_target[pos_inds] + pos_bbox_weight = bbox_weight[pos_inds] + pos_anchors = anchors_all_level[pos_inds] + pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred) + + # to keep loss dimension + loss_cls = self.loss_cls( + pos_scores, + pos_label, + pos_label_weight, + avg_factor=self.loss_cls.loss_weight, + reduction_override='none') + + loss_bbox = self.loss_bbox( + pos_bbox_pred, + pos_bbox_target, + pos_bbox_weight, + avg_factor=self.loss_cls.loss_weight, + reduction_override='none') + + loss_cls = loss_cls.sum(-1) + pos_loss = loss_bbox + loss_cls + return pos_loss, + + def paa_reassign(self, pos_losses, label, label_weight, bbox_weight, + pos_inds, pos_gt_inds, anchors): + """Fit loss to GMM distribution and separate positive, ignore, negative + samples again with GMM model. + + Args: + pos_losses (Tensor): Losses of all positive samples in + single image. + label (Tensor): classification target of each anchor with + shape (num_anchors,) + label_weight (Tensor): Classification loss weight of each + anchor with shape (num_anchors). + bbox_weight (Tensor): Bbox weight of each anchor with shape + (num_anchors, 4). + pos_inds (Tensor): Index of all positive samples got from + first assign process. + pos_gt_inds (Tensor): Gt_index of all positive samples got + from first assign process. + anchors (list[Tensor]): Anchors of each scale. + + Returns: + tuple: Usually returns a tuple containing learning targets. + + - label (Tensor): classification target of each anchor after + paa assign, with shape (num_anchors,) + - label_weight (Tensor): Classification loss weight of each + anchor after paa assign, with shape (num_anchors). + - bbox_weight (Tensor): Bbox weight of each anchor with shape + (num_anchors, 4). + - num_pos (int): The number of positive samples after paa + assign. + """ + if not len(pos_inds): + return label, label_weight, bbox_weight, 0 + + num_gt = pos_gt_inds.max() + 1 + num_level = len(anchors) + num_anchors_each_level = [item.size(0) for item in anchors] + num_anchors_each_level.insert(0, 0) + inds_level_interval = np.cumsum(num_anchors_each_level) + pos_level_mask = [] + for i in range(num_level): + mask = (pos_inds >= inds_level_interval[i]) & ( + pos_inds < inds_level_interval[i + 1]) + pos_level_mask.append(mask) + pos_inds_after_paa = [] + ignore_inds_after_paa = [] + for gt_ind in range(num_gt): + pos_inds_gmm = [] + pos_loss_gmm = [] + gt_mask = pos_gt_inds == gt_ind + for level in range(num_level): + level_mask = pos_level_mask[level] + level_gt_mask = level_mask & gt_mask + value, topk_inds = pos_losses[level_gt_mask].topk( + min(level_gt_mask.sum(), self.topk), largest=False) + pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds]) + pos_loss_gmm.append(value) + pos_inds_gmm = torch.cat(pos_inds_gmm) + pos_loss_gmm = torch.cat(pos_loss_gmm) + # fix gmm need at least two sample + if len(pos_inds_gmm) < 2: + continue + device = pos_inds_gmm.device + pos_loss_gmm, sort_inds = pos_loss_gmm.sort() + pos_inds_gmm = pos_inds_gmm[sort_inds] + pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy() + min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max() + means_init = [[min_loss], [max_loss]] + weights_init = [0.5, 0.5] + precisions_init = [[[1.0]], [[1.0]]] + if skm is None: + raise ImportError('Please run "pip install sklearn" ' + 'to install sklearn first.') + gmm = skm.GaussianMixture( + 2, + weights_init=weights_init, + means_init=means_init, + precisions_init=precisions_init) + gmm.fit(pos_loss_gmm) + gmm_assignment = gmm.predict(pos_loss_gmm) + scores = gmm.score_samples(pos_loss_gmm) + gmm_assignment = torch.from_numpy(gmm_assignment).to(device) + scores = torch.from_numpy(scores).to(device) + + pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme( + gmm_assignment, scores, pos_inds_gmm) + pos_inds_after_paa.append(pos_inds_temp) + ignore_inds_after_paa.append(ignore_inds_temp) + + pos_inds_after_paa = torch.cat(pos_inds_after_paa) + ignore_inds_after_paa = torch.cat(ignore_inds_after_paa) + reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1) + reassign_ids = pos_inds[reassign_mask] + label[reassign_ids] = self.background_label + label_weight[ignore_inds_after_paa] = 0 + bbox_weight[reassign_ids] = 0 + num_pos = len(pos_inds_after_paa) + return label, label_weight, bbox_weight, num_pos + + def gmm_separation_scheme(self, gmm_assignment, scores, pos_inds_gmm): + """A general separation scheme for gmm model. + + It separates a GMM distribution of candidate samples into three + parts, 0 1 and uncertain areas, and you can implement other + separation schemes by rewriting this function. + + Args: + gmm_assignment (Tensor): The prediction of GMM which is of shape + (num_samples,). The 0/1 value indicates the distribution + that each sample comes from. + scores (Tensor): The probability of sample coming from the + fit GMM distribution. The tensor is of shape (num_samples,). + pos_inds_gmm (Tensor): All the indexes of samples which are used + to fit GMM model. The tensor is of shape (num_samples,) + + Returns: + tuple[Tensor]: The indices of positive and ignored samples. + + - pos_inds_temp (Tensor): Indices of positive samples. + - ignore_inds_temp (Tensor): Indices of ignore samples. + """ + # The implementation is (c) in Fig.3 in origin paper intead of (b). + # You can refer to issues such as + # https://github.com/kkhoot/PAA/issues/8 and + # https://github.com/kkhoot/PAA/issues/9. + fgs = gmm_assignment == 0 + if fgs.nonzero().numel(): + _, pos_thr_ind = scores[fgs].topk(1) + pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1] + ignore_inds_temp = pos_inds_gmm.new_tensor([]) + return pos_inds_temp, ignore_inds_temp + + def get_targets( + self, + anchor_list, + valid_flag_list, + gt_bboxes_list, + img_metas, + gt_bboxes_ignore_list=None, + gt_labels_list=None, + label_channels=1, + unmap_outputs=True, + ): + """Get targets for PAA head. + + This method is almost the same as `AnchorHead.get_targets()`. We direct + return the results from _get_targets_single instead map it to levels + by images_to_levels function. + + Args: + anchor_list (list[list[Tensor]]): Multi level anchors of each + image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, 4). + valid_flag_list (list[list[Tensor]]): Multi level valid flags of + each image. The outer list indicates images, and the inner list + corresponds to feature levels of the image. Each element of + the inner list is a tensor of shape (num_anchors, ) + gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image. + img_metas (list[dict]): Meta info of each image. + gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be + ignored. + gt_labels_list (list[Tensor]): Ground truth labels of each box. + label_channels (int): Channel of label. + unmap_outputs (bool): Whether to map outputs back to the original + set of anchors. + + Returns: + tuple: Usually returns a tuple containing learning targets. + + - labels (list[Tensor]): Labels of all anchors, each with + shape (num_anchors,). + - label_weights (list[Tensor]): Label weights of all anchor. + each with shape (num_anchors,). + - bbox_targets (list[Tensor]): BBox targets of all anchors. + each with shape (num_anchors, 4). + - bbox_weights (list[Tensor]): BBox weights of all anchors. + each with shape (num_anchors, 4). + - pos_inds (list[Tensor]): Contains all index of positive + sample in all anchor. + - gt_inds (list[Tensor]): Contains all gt_index of positive + sample in all anchor. + """ + + num_imgs = len(img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + concat_anchor_list = [] + concat_valid_flag_list = [] + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + concat_anchor_list.append(torch.cat(anchor_list[i])) + concat_valid_flag_list.append(torch.cat(valid_flag_list[i])) + + # compute targets for each image + if gt_bboxes_ignore_list is None: + gt_bboxes_ignore_list = [None for _ in range(num_imgs)] + if gt_labels_list is None: + gt_labels_list = [None for _ in range(num_imgs)] + results = multi_apply( + self._get_targets_single, + concat_anchor_list, + concat_valid_flag_list, + gt_bboxes_list, + gt_bboxes_ignore_list, + gt_labels_list, + img_metas, + label_channels=label_channels, + unmap_outputs=unmap_outputs) + + (labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds, + valid_neg_inds, sampling_result) = results + + # Due to valid flag of anchors, we have to calculate the real pos_inds + # in origin anchor set. + pos_inds = [] + for i, single_labels in enumerate(labels): + pos_mask = (0 <= single_labels) & ( + single_labels < self.background_label) + pos_inds.append(pos_mask.nonzero().view(-1)) + + gt_inds = [item.pos_assigned_gt_inds for item in sampling_result] + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + gt_inds) + + def _get_targets_single(self, + flat_anchors, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + img_meta, + label_channels=1, + unmap_outputs=True): + """Compute regression and classification targets for anchors in a + single image. + + This method is same as `AnchorHead._get_targets_single()`. + """ + assert unmap_outputs, 'We must map outputs back to the original' \ + 'set of anchors in PAAhead' + return super(ATSSHead, self)._get_targets_single( + flat_anchors, + valid_flags, + gt_bboxes, + gt_bboxes_ignore, + gt_labels, + img_meta, + label_channels=1, + unmap_outputs=True) + + def _get_bboxes_single(self, + cls_scores, + bbox_preds, + iou_preds, + mlvl_anchors, + img_shape, + scale_factor, + cfg, + rescale=False): + """Transform outputs for a single batch item into labeled boxes. + + This method is almost same as `ATSSHead._get_bboxes_single()`. + We use sqrt(iou_preds * cls_scores) in NMS process instead of just + cls_scores. Besides, score voting is used when `` score_voting`` + is set to True. + """ + assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors) + mlvl_bboxes = [] + mlvl_scores = [] + mlvl_iou_preds = [] + for cls_score, bbox_pred, iou_preds, anchors in zip( + cls_scores, bbox_preds, iou_preds, mlvl_anchors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + scores = cls_score.permute(1, 2, 0).reshape( + -1, self.cls_out_channels).sigmoid() + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4) + iou_preds = iou_preds.permute(1, 2, 0).reshape(-1).sigmoid() + nms_pre = cfg.get('nms_pre', -1) + if nms_pre > 0 and scores.shape[0] > nms_pre: + max_scores, _ = (scores * iou_preds[:, None]).sqrt().max(dim=1) + _, topk_inds = max_scores.topk(nms_pre) + anchors = anchors[topk_inds, :] + bbox_pred = bbox_pred[topk_inds, :] + scores = scores[topk_inds, :] + iou_preds = iou_preds[topk_inds] + + bboxes = self.bbox_coder.decode( + anchors, bbox_pred, max_shape=img_shape) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + mlvl_iou_preds.append(iou_preds) + + mlvl_bboxes = torch.cat(mlvl_bboxes) + if rescale: + mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor) + mlvl_scores = torch.cat(mlvl_scores) + # Add a dummy background class to the backend when using sigmoid + # remind that we set FG labels to [0, num_class-1] since mmdet v2.0 + # BG cat_id: num_class + padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1) + mlvl_scores = torch.cat([mlvl_scores, padding], dim=1) + mlvl_iou_preds = torch.cat(mlvl_iou_preds) + mlvl_nms_scores = (mlvl_scores * mlvl_iou_preds[:, None]).sqrt() + det_bboxes, det_labels = multiclass_nms( + mlvl_bboxes, + mlvl_nms_scores, + cfg.score_thr, + cfg.nms, + cfg.max_per_img, + score_factors=None) + if self.with_score_voting: + det_bboxes, det_labels = self.score_voting(det_bboxes, det_labels, + mlvl_bboxes, + mlvl_nms_scores, + cfg.score_thr) + + return det_bboxes, det_labels + + def score_voting(self, det_bboxes, det_labels, mlvl_bboxes, + mlvl_nms_scores, score_thr): + """Implementation of score voting method works on each remaining boxes + after NMS procedure. + + Args: + det_bboxes (Tensor): Remaining boxes after NMS procedure, + with shape (k, 5), each dimension means + (x1, y1, x2, y2, score). + det_labels (Tensor): The label of remaining boxes, with shape + (k, 1),Labels are 0-based. + mlvl_bboxes (Tensor): All boxes before the NMS procedure, + with shape (num_anchors,4). + mlvl_nms_scores (Tensor): The scores of all boxes which is used + in the NMS procedure, with shape (num_anchors, num_class) + mlvl_iou_preds (Tensot): The predictions of IOU of all boxes + before the NMS procedure, with shape (num_anchors, 1) + score_thr (float): The score threshold of bboxes. + + Returns: + tuple: Usually returns a tuple containing voting results. + + - det_bboxes_voted (Tensor): Remaining boxes after + score voting procedure, with shape (k, 5), each + dimension means (x1, y1, x2, y2, score). + - det_labels_voted (Tensor): Label of remaining bboxes + after voting, with shape (num_anchors,). + """ + candidate_mask = mlvl_nms_scores > score_thr + candidate_mask_nozeros = candidate_mask.nonzero() + candidate_inds = candidate_mask_nozeros[:, 0] + candidate_labels = candidate_mask_nozeros[:, 1] + candidate_bboxes = mlvl_bboxes[candidate_inds] + candidate_scores = mlvl_nms_scores[candidate_mask] + det_bboxes_voted = [] + det_labels_voted = [] + for cls in range(self.cls_out_channels): + candidate_cls_mask = candidate_labels == cls + if not candidate_cls_mask.any(): + continue + candidate_cls_scores = candidate_scores[candidate_cls_mask] + candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask] + det_cls_mask = det_labels == cls + det_cls_bboxes = det_bboxes[det_cls_mask].view( + -1, det_bboxes.size(-1)) + det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4], + candidate_cls_bboxes) + for det_ind in range(len(det_cls_bboxes)): + single_det_ious = det_candidate_ious[det_ind] + pos_ious_mask = single_det_ious > 0.01 + pos_ious = single_det_ious[pos_ious_mask] + pos_bboxes = candidate_cls_bboxes[pos_ious_mask] + pos_scores = candidate_cls_scores[pos_ious_mask] + pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) * + pos_scores)[:, None] + voted_box = torch.sum( + pis * pos_bboxes, dim=0) / torch.sum( + pis, dim=0) + voted_score = det_cls_bboxes[det_ind][-1:][None, :] + det_bboxes_voted.append( + torch.cat((voted_box[None, :], voted_score), dim=1)) + det_labels_voted.append(cls) + + det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0) + det_labels_voted = det_labels.new_tensor(det_labels_voted) + return det_bboxes_voted, det_labels_voted diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index 9e335c3e6..1cd8e621f 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -13,6 +13,7 @@ from .htc import HybridTaskCascade from .mask_rcnn import MaskRCNN from .mask_scoring_rcnn import MaskScoringRCNN from .nasfcos import NASFCOS +from .paa import PAA from .point_rend import PointRend from .reppoints_detector import RepPointsDetector from .retinanet import RetinaNet @@ -24,5 +25,5 @@ __all__ = [ 'ATSS', 'BaseDetector', 'SingleStageDetector', 'TwoStageDetector', 'RPN', 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector', - 'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet' + 'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA' ] diff --git a/mmdet/models/detectors/paa.py b/mmdet/models/detectors/paa.py new file mode 100644 index 000000000..9b4bb5e09 --- /dev/null +++ b/mmdet/models/detectors/paa.py @@ -0,0 +1,17 @@ +from ..builder import DETECTORS +from .single_stage import SingleStageDetector + + +@DETECTORS.register_module() +class PAA(SingleStageDetector): + """Implementation of `PAA `_.""" + + def __init__(self, + backbone, + neck, + bbox_head, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(PAA, self).__init__(backbone, neck, bbox_head, train_cfg, + test_cfg, pretrained) diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 8d9c5e00d..65aa1a07d 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -1,14 +1,133 @@ import mmcv +import numpy as np import torch from mmdet.core import bbox2roi, build_assigner, build_sampler from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps from mmdet.models.dense_heads import (AnchorHead, CornerHead, FCOSHead, - FSAFHead, GuidedAnchorHead) + FSAFHead, GuidedAnchorHead, PAAHead, + paa_head) +from mmdet.models.dense_heads.paa_head import levels_to_images from mmdet.models.roi_heads.bbox_heads import BBoxHead from mmdet.models.roi_heads.mask_heads import FCNMaskHead, MaskIoUHead +def test_paa_head_loss(): + """Tests paa head loss when truth is empty and non-empty.""" + + class mock_skm(object): + + def GaussianMixture(self, *args, **kwargs): + return self + + def fit(self, loss): + pass + + def predict(self, loss): + components = np.zeros_like(loss, dtype=np.long) + return components.reshape(-1) + + def score_samples(self, loss): + scores = np.random.random(len(loss)) + return scores + + paa_head.skm = mock_skm() + + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + train_cfg = mmcv.Config( + dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.1, + neg_iou_thr=0.1, + min_pos_iou=0, + ignore_iof_thr=-1), + allowed_border=-1, + pos_weight=-1, + debug=False)) + # since Focal Loss is not supported on CPU + self = PAAHead( + num_classes=4, + in_channels=1, + train_cfg=train_cfg, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=1.3), + loss_centerness=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.5)) + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32, 64] + ] + self.init_weights() + cls_scores, bbox_preds, iou_preds = self(feat) + # Test that empty ground truth encourages the network to predict background + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, iou_preds, gt_bboxes, + gt_labels, img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + empty_cls_loss = empty_gt_losses['loss_cls'] + empty_box_loss = empty_gt_losses['loss_bbox'] + empty_iou_loss = empty_gt_losses['loss_iou'] + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_box_loss.item() == 0, ( + 'there should be no box loss when there are no true boxes') + assert empty_iou_loss.item() == 0, ( + 'there should be no box loss when there are no true boxes') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, iou_preds, gt_bboxes, + gt_labels, img_metas, gt_bboxes_ignore) + onegt_cls_loss = one_gt_losses['loss_cls'] + onegt_box_loss = one_gt_losses['loss_bbox'] + onegt_iou_loss = one_gt_losses['loss_iou'] + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_box_loss.item() > 0, 'box loss should be non-zero' + assert onegt_iou_loss.item() > 0, 'box loss should be non-zero' + n, c, h, w = 10, 4, 20, 20 + mlvl_tensor = [torch.ones(n, c, h, w) for i in range(5)] + results = levels_to_images(mlvl_tensor) + assert len(results) == n + assert results[0].size() == (h * w * 5, c) + assert self.with_score_voting + cls_scores = [torch.ones(4, 5, 5)] + bbox_preds = [torch.ones(4, 5, 5)] + iou_preds = [torch.ones(1, 5, 5)] + mlvl_anchors = [torch.ones(5 * 5, 4)] + img_shape = None + scale_factor = [0.5, 0.5] + cfg = mmcv.Config( + dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + rescale = False + self._get_bboxes_single( + cls_scores, + bbox_preds, + iou_preds, + mlvl_anchors, + img_shape, + scale_factor, + cfg, + rescale=rescale) + + def test_fcos_head_loss(): """Tests fcos head loss when truth is empty and non-empty.""" s = 256