diff --git a/configs/yolof/README.md b/configs/yolof/README.md
new file mode 100644
index 000000000..4eb9a4b65
--- /dev/null
+++ b/configs/yolof/README.md
@@ -0,0 +1,25 @@
+# You Only Look One-level Feature
+
+## Introduction
+
+
+
+```
+@inproceedings{chen2021you,
+  title={You Only Look One-level Feature},
+  author={Chen, Qiang and Wang, Yingming and Yang, Tong and Zhang, Xiangyu and Cheng, Jian and Sun, Jian},
+  booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+  year={2021}
+}
+```
+
+## Results and Models
+
+| Backbone | Style | Epoch | Lr schd | Mem (GB) | box AP | Config | Download |
+|:---------:|:-------:|:-------:|:-------:|:--------:|:------:|:------:|:--------:|
+| R-50-C5 | caffe | Y | 1x | 8.3 | 37.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolof/yolof_r50_c5_8x8_1x_coco.py) | [model](http://download.openmmlab.com/mmdetection/v2.0/yolof/yolof_r50_c5_8x8_1x_coco/yolof_r50_c5_8x8_1x_coco_20210425_024427-8e864411.pth) &#124; [log](http://download.openmmlab.com/mmdetection/v2.0/yolof/yolof_r50_c5_8x8_1x_coco/yolof_r50_c5_8x8_1x_coco_20210425_024427.log.json) |
+
+**Note**:
+
+1. We observe that the performance is unstable and may fluctuate by about 0.3 mAP; a final mAP in the 37.4 ~ 37.7 range is acceptable for YOLOF_R_50_C5_1x. The same fluctuation can be seen in the [original implementation](https://github.com/chensnathan/YOLOF).
+2. Besides the instability above, training sometimes shows large loss fluctuations and NaN values, so there may still be problems with this implementation, which will be addressed subsequently.
diff --git a/configs/yolof/yolof_r50_c5_8x8_1x_coco.py b/configs/yolof/yolof_r50_c5_8x8_1x_coco.py
new file mode 100644
index 000000000..e7b31a1f7
--- /dev/null
+++ b/configs/yolof/yolof_r50_c5_8x8_1x_coco.py
@@ -0,0 +1,103 @@
+_base_ = [
+    '../_base_/datasets/coco_detection.py',
+    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+model = dict(
+    type='YOLOF',
+    pretrained='open-mmlab://detectron/resnet50_caffe',
+    backbone=dict(
+        type='ResNet',
+        depth=50,
+        num_stages=4,
+        out_indices=(3, ),
+        frozen_stages=1,
+        norm_cfg=dict(type='BN', requires_grad=False),
+        norm_eval=True,
+        style='caffe'),
+    neck=dict(
+        type='DilatedEncoder',
+        in_channels=2048,
+        out_channels=512,
+        block_mid_channels=128,
+        num_residual_blocks=4),
+    bbox_head=dict(
+        type='YOLOFHead',
+        num_classes=80,
+        in_channels=512,
+        reg_decoded_bbox=True,
+        anchor_generator=dict(
+            type='AnchorGenerator',
+            ratios=[1.0],
+            scales=[1, 2, 4, 8, 16],
+            strides=[32]),
+        bbox_coder=dict(
+            type='DeltaXYWHBBoxCoder',
+            target_means=[.0, .0, .0, .0],
+            target_stds=[1., 1., 1., 1.],
+            add_ctr_clamp=True,
+            ctr_clamp=32),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='GIoULoss', loss_weight=1.0)),
+    # training and testing settings
+    train_cfg=dict(
+        assigner=dict(
+            type='UniformAssigner', pos_ignore_thr=0.15, neg_ignore_thr=0.7),
+        allowed_border=-1,
+        pos_weight=-1,
+        debug=False),
+    test_cfg=dict(
+        nms_pre=1000,
+        min_bbox_size=0,
+        score_thr=0.05,
+        nms=dict(type='nms', iou_threshold=0.6),
+        max_per_img=100))
+# optimizer
+optimizer = dict(
+    type='SGD',
+    lr=0.12,
+    momentum=0.9,
+    weight_decay=0.0001,
+    paramwise_cfg=dict(
+        norm_decay_mult=0., custom_keys={'backbone': dict(lr_mult=1. / 3)}))
+lr_config = dict(warmup_iters=1500, warmup_ratio=0.00066667)
+
+# use caffe img_norm
+img_norm_cfg = dict(
+    mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+    dict(type='RandomFlip', flip_ratio=0.5),
+    dict(type='RandomShift', shift_ratio=0.5, max_shift_px=32),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size_divisor=32),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(1333, 800),
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='Pad', size_divisor=32),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    samples_per_gpu=8,
+    workers_per_gpu=8,
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
diff --git a/configs/yolof/yolof_r50_c5_8x8_iter-1x_coco.py b/configs/yolof/yolof_r50_c5_8x8_iter-1x_coco.py
new file mode 100644
index 000000000..c95c02da1
--- /dev/null
+++ b/configs/yolof/yolof_r50_c5_8x8_iter-1x_coco.py
@@ -0,0 +1,14 @@
+_base_ = './yolof_r50_c5_8x8_1x_coco.py'
+
+# We implemented the iter-based config according to the source code.
+# The COCO dataset has 117266 images after filtering. We train with
+# 8 GPUs and 8 images per GPU, so 22500 iterations are equivalent to
+# 22500/(117266/(8x8))=12.3 epochs, 15000 iterations to 8.2 epochs and
+# 20000 iterations to 10.9 epochs (see the worked check below). Because
+# the lr (0.12) is large, the iter-based and epoch-based settings
+# differ by about 0.2 mAP in evaluation.
+lr_config = dict(step=[15000, 20000])
+runner = dict(_delete_=True, type='IterBasedRunner', max_iters=22500)
+checkpoint_config = dict(interval=2500)
+evaluation = dict(interval=4500)
+log_config = dict(interval=20)
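The epoch equivalences quoted in the comment above are easy to verify; here is a quick back-of-the-envelope check (ours, not part of the patch):

```python
# Iteration-to-epoch conversion for the iter-based YOLOF schedule,
# assuming 117266 filtered COCO images and 8 GPUs x 8 samples per GPU.
imgs_per_iter = 8 * 8
iters_per_epoch = 117266 / imgs_per_iter  # ~1832.3
for iters in (15000, 20000, 22500):
    print(f'{iters} iters ~ {iters / iters_per_epoch:.1f} epochs')
# 15000 iters ~ 8.2 epochs
# 20000 iters ~ 10.9 epochs
# 22500 iters ~ 12.3 epochs
```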
diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py
index 95e34a848..891e6237c 100644
--- a/mmdet/core/bbox/assigners/__init__.py
+++ b/mmdet/core/bbox/assigners/__init__.py
@@ -8,9 +8,10 @@ from .hungarian_assigner import HungarianAssigner
 from .max_iou_assigner import MaxIoUAssigner
 from .point_assigner import PointAssigner
 from .region_assigner import RegionAssigner
+from .uniform_assigner import UniformAssigner

 __all__ = [
     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
     'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
-    'HungarianAssigner', 'RegionAssigner'
+    'HungarianAssigner', 'RegionAssigner', 'UniformAssigner'
 ]
diff --git a/mmdet/core/bbox/assigners/uniform_assigner.py b/mmdet/core/bbox/assigners/uniform_assigner.py
new file mode 100644
index 000000000..1d606dee9
--- /dev/null
+++ b/mmdet/core/bbox/assigners/uniform_assigner.py
@@ -0,0 +1,134 @@
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from ..transforms import bbox_xyxy_to_cxcywh
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class UniformAssigner(BaseAssigner):
+    """Uniform Matching between the anchors and gt boxes, which can achieve
+    balance in positive anchors. `gt_bboxes_ignore` is not considered for
+    now.
+
+    Args:
+        pos_ignore_thr (float): Positive matches whose IoU with their gt
+            box is below this threshold are ignored.
+        neg_ignore_thr (float): Predictions whose max IoU with any gt box
+            is above this threshold are ignored as negatives.
+        match_times (int): Number of positive anchors for each gt box.
+            Default 4.
+        iou_calculator (dict): Config of the IoU calculator.
+    """
+
+    def __init__(self,
+                 pos_ignore_thr,
+                 neg_ignore_thr,
+                 match_times=4,
+                 iou_calculator=dict(type='BboxOverlaps2D')):
+        self.match_times = match_times
+        self.pos_ignore_thr = pos_ignore_thr
+        self.neg_ignore_thr = neg_ignore_thr
+        self.iou_calculator = build_iou_calculator(iou_calculator)
+
+    def assign(self,
+               bbox_pred,
+               anchor,
+               gt_bboxes,
+               gt_bboxes_ignore=None,
+               gt_labels=None):
+        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
+
+        # 1. assign 0 (background) by default
+        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
+                                              0,
+                                              dtype=torch.long)
+        assigned_labels = bbox_pred.new_full((num_bboxes, ),
+                                             -1,
+                                             dtype=torch.long)
+        if num_gts == 0 or num_bboxes == 0:
+            # No ground truth or boxes, return empty assignment
+            if num_gts == 0:
+                # No ground truth, assign all to background
+                assigned_gt_inds[:] = 0
+            assign_result = AssignResult(
+                num_gts, assigned_gt_inds, None, labels=assigned_labels)
+            assign_result.set_extra_property(
+                'pos_idx', bbox_pred.new_empty(0, dtype=torch.bool))
+            assign_result.set_extra_property('pos_predicted_boxes',
+                                             bbox_pred.new_empty((0, 4)))
+            assign_result.set_extra_property('target_boxes',
+                                             bbox_pred.new_empty((0, 4)))
+            return assign_result
+
+        # 2. Compute the L1 cost between boxes.
+        # Note that we use both the anchors and the predicted boxes.
+        cost_bbox = torch.cdist(
+            bbox_xyxy_to_cxcywh(bbox_pred),
+            bbox_xyxy_to_cxcywh(gt_bboxes),
+            p=1)
+        cost_bbox_anchors = torch.cdist(
+            bbox_xyxy_to_cxcywh(anchor), bbox_xyxy_to_cxcywh(gt_bboxes), p=1)
+
+        # We found that the topk function gives different results in cpu
+        # and cuda mode. In order to ensure consistency with the source
+        # code, we also use cpu mode.
+        # TODO: Check whether the performance of cpu and cuda are the same.
+        C = cost_bbox.cpu()
+        C1 = cost_bbox_anchors.cpu()
+
+        # index: (match_times, num_gts)
+        index = torch.topk(
+            C,  # C is of shape (num_bboxes, num_gts)
+            k=self.match_times,
+            dim=0,
+            largest=False)[1]
+
+        # index1: (match_times, num_gts)
+        index1 = torch.topk(C1, k=self.match_times, dim=0, largest=False)[1]
+        # indexes: (match_times * 2 * num_gts, )
+        indexes = torch.cat((index, index1),
+                            dim=1).reshape(-1).to(bbox_pred.device)
+
+        pred_overlaps = self.iou_calculator(bbox_pred, gt_bboxes)
+        anchor_overlaps = self.iou_calculator(anchor, gt_bboxes)
+        pred_max_overlaps, _ = pred_overlaps.max(dim=1)
+        anchor_max_overlaps, _ = anchor_overlaps.max(dim=0)
+
+        # 3. Compute the ignore indexes of negatives using gt_bboxes and
+        # predicted boxes
+        ignore_idx = pred_max_overlaps > self.neg_ignore_thr
+        assigned_gt_inds[ignore_idx] = -1
+
+        # 4. Compute the ignore indexes of positive samples using anchors
+        # and predicted boxes
+        pos_gt_index = torch.arange(
+            0, C1.size(1),
+            device=bbox_pred.device).repeat(self.match_times * 2)
+        pos_ious = anchor_overlaps[indexes, pos_gt_index]
+        pos_ignore_idx = pos_ious < self.pos_ignore_thr
+
+        pos_gt_index_with_ignore = pos_gt_index + 1
+        pos_gt_index_with_ignore[pos_ignore_idx] = -1
+        assigned_gt_inds[indexes] = pos_gt_index_with_ignore
+
+        if gt_labels is not None:
+            assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+            pos_inds = torch.nonzero(
+                assigned_gt_inds > 0, as_tuple=False).squeeze()
+            if pos_inds.numel() > 0:
+                assigned_labels[pos_inds] = gt_labels[
+                    assigned_gt_inds[pos_inds] - 1]
+        else:
+            assigned_labels = None
+
+        assign_result = AssignResult(
+            num_gts,
+            assigned_gt_inds,
+            anchor_max_overlaps,
+            labels=assigned_labels)
+        assign_result.set_extra_property('pos_idx', ~pos_ignore_idx)
+        assign_result.set_extra_property('pos_predicted_boxes',
+                                         bbox_pred[indexes])
+        assign_result.set_extra_property('target_boxes',
+                                         gt_bboxes[pos_gt_index])
+        return assign_result
diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
index 51e17325f..98d30906d 100644
--- a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -21,16 +21,25 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
         target for delta coordinates
         clip_border (bool, optional): Whether clip the objects outside the
             border of the image. Defaults to True.
+        add_ctr_clamp (bool): Whether to add center clamp. When set, the
+            predicted box is clamped if its center is too far away from
+            the original anchor's center. Only used by YOLOF. Default False.
+        ctr_clamp (int): The maximum pixel shift to clamp. Only used by
+            YOLOF. Default 32.
     """

     def __init__(self,
                  target_means=(0., 0., 0., 0.),
                  target_stds=(1., 1., 1., 1.),
-                 clip_border=True):
+                 clip_border=True,
+                 add_ctr_clamp=False,
+                 ctr_clamp=32):
         super(BaseBBoxCoder, self).__init__()
         self.means = target_means
         self.stds = target_stds
         self.clip_border = clip_border
+        self.add_ctr_clamp = add_ctr_clamp
+        self.ctr_clamp = ctr_clamp

     def encode(self, bboxes, gt_bboxes):
         """Get box regression transformation deltas that can be used to
@@ -79,7 +88,8 @@ class DeltaXYWHBBoxCoder(BaseBBoxCoder):
         if pred_bboxes.ndim == 3:
             assert pred_bboxes.size(1) == bboxes.size(1)
         decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
-                                    max_shape, wh_ratio_clip, self.clip_border)
+                                    max_shape, wh_ratio_clip, self.clip_border,
+                                    self.add_ctr_clamp, self.ctr_clamp)

         return decoded_bboxes

@@ -137,7 +147,9 @@ def delta2bbox(rois,
                stds=(1., 1., 1., 1.),
                max_shape=None,
                wh_ratio_clip=16 / 1000,
-               clip_border=True):
+               clip_border=True,
+               add_ctr_clamp=False,
+               ctr_clamp=32):
     """Apply deltas to shift/scale base boxes.

     Typically the rois are anchor or proposed bounding boxes and the deltas
@@ -161,6 +173,11 @@ def delta2bbox(rois,
         wh_ratio_clip (float): Maximum aspect ratio for boxes.
         clip_border (bool, optional): Whether clip the objects outside the
             border of the image. Defaults to True.
+        add_ctr_clamp (bool): Whether to add center clamp. When set, the
+            predicted box is clamped if its center is too far away from
+            the original anchor's center. Only used by YOLOF. Default False.
+        ctr_clamp (int): The maximum pixel shift to clamp. Only used by
+            YOLOF. Default 32.
Returns: Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or @@ -194,9 +211,7 @@ def delta2bbox(rois, dy = denorm_deltas[..., 1::4] dw = denorm_deltas[..., 2::4] dh = denorm_deltas[..., 3::4] - max_ratio = np.abs(np.log(wh_ratio_clip)) - dw = dw.clamp(min=-max_ratio, max=max_ratio) - dh = dh.clamp(min=-max_ratio, max=max_ratio) + x1, y1 = rois[..., 0], rois[..., 1] x2, y2 = rois[..., 2], rois[..., 3] # Compute center of each roi @@ -205,12 +220,25 @@ def delta2bbox(rois, # Compute width/height of each roi pw = (x2 - x1).unsqueeze(-1).expand_as(dw) ph = (y2 - y1).unsqueeze(-1).expand_as(dh) + + dx_width = pw * dx + dy_height = ph * dy + + max_ratio = np.abs(np.log(wh_ratio_clip)) + if add_ctr_clamp: + dx_width = torch.clamp(dx_width, max=ctr_clamp, min=-ctr_clamp) + dy_height = torch.clamp(dy_height, max=ctr_clamp, min=-ctr_clamp) + dw = torch.clamp(dw, max=max_ratio) + dh = torch.clamp(dh, max=max_ratio) + else: + dw = dw.clamp(min=-max_ratio, max=max_ratio) + dh = dh.clamp(min=-max_ratio, max=max_ratio) # Use exp(network energy) to enlarge/shrink each roi gw = pw * dw.exp() gh = ph * dh.exp() # Use network energy to shift the center of each roi - gx = px + pw * dx - gy = py + ph * dy + gx = px + dx_width + gy = py + dy_height # Convert center-xy/width/height to top-left, bottom-right x1 = gx - gw * 0.5 y1 = gy - gh * 0.5 diff --git a/mmdet/datasets/pipelines/__init__.py b/mmdet/datasets/pipelines/__init__.py index c6f424deb..9559969a2 100644 --- a/mmdet/datasets/pipelines/__init__.py +++ b/mmdet/datasets/pipelines/__init__.py @@ -10,7 +10,8 @@ from .loading import (LoadAnnotations, LoadImageFromFile, LoadImageFromWebcam, from .test_time_aug import MultiScaleFlipAug from .transforms import (Albu, CutOut, Expand, MinIoURandomCrop, Normalize, Pad, PhotoMetricDistortion, RandomCenterCropPad, - RandomCrop, RandomFlip, Resize, SegRescale) + RandomCrop, RandomFlip, RandomShift, Resize, + SegRescale) __all__ = [ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', @@ -21,5 +22,5 @@ __all__ = [ 'MinIoURandomCrop', 'Expand', 'PhotoMetricDistortion', 'Albu', 'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear', 'Rotate', 'ColorTransform', 'EqualizeTransform', 'BrightnessTransform', - 'ContrastTransform', 'Translate' + 'ContrastTransform', 'Translate', 'RandomShift' ] diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index caed51d89..c777b31f1 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -472,6 +472,96 @@ class RandomFlip(object): return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})' +@PIPELINES.register_module() +class RandomShift(object): + """Shift the image and box given shift pixels and probability. + + Args: + shift_ratio (float): Probability of shifts. Default 0.5. + max_shift_px (int): The max pixels for shifting. Default 32. + filter_thr_px (int): The width and height threshold for filtering. + The bbox and the rest of the targets below the width and + height threshold will be filtered. Default 1. + """ + + def __init__(self, shift_ratio=0.5, max_shift_px=32, filter_thr_px=1): + assert 0 <= shift_ratio <= 1 + assert max_shift_px >= 0 + self.shift_ratio = shift_ratio + self.max_shift_px = max_shift_px + self.filter_thr_px = int(filter_thr_px) + # The key correspondence from bboxes to labels. 
+        self.bbox2label = {
+            'gt_bboxes': 'gt_labels',
+            'gt_bboxes_ignore': 'gt_labels_ignore'
+        }
+
+    def __call__(self, results):
+        """Call function to randomly shift images and bounding boxes.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Shift results.
+        """
+        if random.random() < self.shift_ratio:
+            img_shape = results['img'].shape[:2]
+
+            random_shift_x = random.randint(-self.max_shift_px,
+                                            self.max_shift_px)
+            random_shift_y = random.randint(-self.max_shift_px,
+                                            self.max_shift_px)
+            new_x = max(0, random_shift_x)
+            orig_x = max(0, -random_shift_x)
+            new_y = max(0, random_shift_y)
+            orig_y = max(0, -random_shift_y)
+
+            # TODO: support mask and semantic segmentation maps.
+            for key in results.get('bbox_fields', []):
+                bboxes = results[key].copy()
+                bboxes[..., 0::2] += random_shift_x
+                bboxes[..., 1::2] += random_shift_y
+
+                # clip border
+                bboxes[..., 0::2] = np.clip(bboxes[..., 0::2], 0, img_shape[1])
+                bboxes[..., 1::2] = np.clip(bboxes[..., 1::2], 0, img_shape[0])
+
+                # remove invalid bboxes
+                bbox_w = bboxes[..., 2] - bboxes[..., 0]
+                bbox_h = bboxes[..., 3] - bboxes[..., 1]
+                valid_inds = (bbox_w > self.filter_thr_px) & (
+                    bbox_h > self.filter_thr_px)
+                # If the shift does not contain any gt-bbox area, skip this
+                # image.
+                if key == 'gt_bboxes' and not valid_inds.any():
+                    return results
+                bboxes = bboxes[valid_inds]
+                results[key] = bboxes
+
+                # label fields, e.g. gt_labels and gt_labels_ignore
+                label_key = self.bbox2label.get(key)
+                if label_key in results:
+                    results[label_key] = results[label_key][valid_inds]
+
+            for key in results.get('img_fields', ['img']):
+                img = results[key]
+                new_img = np.zeros_like(img)
+                img_h, img_w = img.shape[:2]
+                new_h = img_h - np.abs(random_shift_y)
+                new_w = img_w - np.abs(random_shift_x)
+                new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
+                    = img[orig_y:orig_y + new_h, orig_x:orig_x + new_w]
+                results[key] = new_img
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f'(shift_ratio={self.shift_ratio}, ' \
+                    f'max_shift_px={self.max_shift_px}, ' \
+                    f'filter_thr_px={self.filter_thr_px})'
+        return repr_str
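To make the copy logic of `RandomShift.__call__` above concrete, here is a toy NumPy sketch (ours, not part of the patch) with made-up shift values:

```python
import numpy as np

# Shift a 4x4 "image" right by 1 px and up by 2 px; the uncovered
# border is zero-filled, mirroring the new_img/orig indexing above.
img = np.arange(16, dtype=np.float32).reshape(4, 4)
sx, sy = 1, -2
new_x, orig_x = max(0, sx), max(0, -sx)
new_y, orig_y = max(0, sy), max(0, -sy)
new_h, new_w = img.shape[0] - abs(sy), img.shape[1] - abs(sx)
new_img = np.zeros_like(img)
new_img[new_y:new_y + new_h, new_x:new_x + new_w] = \
    img[orig_y:orig_y + new_h, orig_x:orig_x + new_w]
# Boxes get the same (sx, sy) offset, are clipped to the image border,
# and any box narrower/shorter than filter_thr_px is dropped together
# with its label.
```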
+
+
 @PIPELINES.register_module()
 class Pad(object):
     """Pad the image & mask.
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
index 4999d8e47..3344bbf6d 100644
--- a/mmdet/models/dense_heads/__init__.py
+++ b/mmdet/models/dense_heads/__init__.py
@@ -29,6 +29,7 @@ from .ssd_head import SSDHead
 from .vfnet_head import VFNetHead
 from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
 from .yolo_head import YOLOV3Head
+from .yolof_head import YOLOFHead

 __all__ = [
     'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
@@ -39,5 +40,5 @@ __all__ = [
     'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
     'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
     'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
-    'AutoAssignHead', 'DETRHead'
+    'AutoAssignHead', 'DETRHead', 'YOLOFHead'
 ]
diff --git a/mmdet/models/dense_heads/yolof_head.py b/mmdet/models/dense_heads/yolof_head.py
new file mode 100644
index 000000000..e15d4d4a6
--- /dev/null
+++ b/mmdet/models/dense_heads/yolof_head.py
@@ -0,0 +1,415 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import (ConvModule, bias_init_with_prob, constant_init, is_norm,
+                      normal_init)
+from mmcv.runner import force_fp32
+
+from mmdet.core import anchor_inside_flags, multi_apply, reduce_mean, unmap
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+
+INF = 1e8
+
+
+def levels_to_images(mlvl_tensor):
+    """Concatenate multi-level feature maps by image.
+
+    [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
+    Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
+    (N, H*W, C), then split each element into N chunks of shape (H*W, C),
+    and concatenate the chunks belonging to the same image across all
+    levels along the first dimension.
+
+    Args:
+        mlvl_tensor (list[torch.Tensor]): List of tensors collected from
+            the corresponding levels. Each element is of shape (N, C, H, W).
+
+    Returns:
+        list[torch.Tensor]: A list that contains N tensors, each of
+            shape (num_elements, C).
+    """
+    batch_size = mlvl_tensor[0].size(0)
+    batch_list = [[] for _ in range(batch_size)]
+    channels = mlvl_tensor[0].size(1)
+    for t in mlvl_tensor:
+        t = t.permute(0, 2, 3, 1)
+        t = t.view(batch_size, -1, channels).contiguous()
+        for img in range(batch_size):
+            batch_list[img].append(t[img])
+    return [torch.cat(item, 0) for item in batch_list]
+
+
+@HEADS.register_module()
+class YOLOFHead(AnchorHead):
+    """YOLOFHead Paper link: https://arxiv.org/abs/2103.09460.
+
+    Args:
+        num_classes (int): The number of object classes (w/o background).
+        in_channels (List[int]): The number of input channels per scale.
+        num_cls_convs (int): The number of convolutions of the cls branch.
+            Default 2.
+        num_reg_convs (int): The number of convolutions of the reg branch.
+            Default 4.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+ """ + + def __init__(self, + num_classes, + in_channels, + num_cls_convs=2, + num_reg_convs=4, + norm_cfg=dict(type='BN', requires_grad=True), + **kwargs): + self.num_cls_convs = num_cls_convs + self.num_reg_convs = num_reg_convs + self.norm_cfg = norm_cfg + super(YOLOFHead, self).__init__(num_classes, in_channels, **kwargs) + + def _init_layers(self): + cls_subnet = [] + bbox_subnet = [] + for i in range(self.num_cls_convs): + cls_subnet.append( + ConvModule( + self.in_channels, + self.in_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg)) + for i in range(self.num_reg_convs): + bbox_subnet.append( + ConvModule( + self.in_channels, + self.in_channels, + kernel_size=3, + padding=1, + norm_cfg=self.norm_cfg)) + self.cls_subnet = nn.Sequential(*cls_subnet) + self.bbox_subnet = nn.Sequential(*bbox_subnet) + self.cls_score = nn.Conv2d( + self.in_channels, + self.num_anchors * self.num_classes, + kernel_size=3, + stride=1, + padding=1) + self.bbox_pred = nn.Conv2d( + self.in_channels, + self.num_anchors * 4, + kernel_size=3, + stride=1, + padding=1) + self.object_pred = nn.Conv2d( + self.in_channels, + self.num_anchors, + kernel_size=3, + stride=1, + padding=1) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, std=0.01) + if is_norm(m): + constant_init(m, 1) + + # Use prior in model initialization to improve stability + bias_cls = bias_init_with_prob(0.01) + torch.nn.init.constant_(self.cls_score.bias, bias_cls) + + def forward_single(self, feature): + cls_score = self.cls_score(self.cls_subnet(feature)) + N, _, H, W = cls_score.shape + cls_score = cls_score.view(N, -1, self.num_classes, H, W) + + reg_feat = self.bbox_subnet(feature) + bbox_reg = self.bbox_pred(reg_feat) + objectness = self.object_pred(reg_feat) + + # implicit objectness + objectness = objectness.view(N, -1, 1, H, W) + normalized_cls_score = cls_score + objectness - torch.log( + 1. + torch.clamp(cls_score.exp(), max=INF) + + torch.clamp(objectness.exp(), max=INF)) + normalized_cls_score = normalized_cls_score.view(N, -1, H, W) + return normalized_cls_score, bbox_reg + + @force_fp32(apply_to=('cls_scores', 'bbox_preds')) + def loss(self, + cls_scores, + bbox_preds, + gt_bboxes, + gt_labels, + img_metas, + gt_bboxes_ignore=None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (batch, num_anchors * num_classes, h, w) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (batch, num_anchors * 4, h, w) + gt_bboxes (list[Tensor]): Ground truth bboxes for each image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[Tensor]): class indices corresponding to each box + img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + gt_bboxes_ignore (None | list[Tensor]): specify which bounding + boxes can be ignored when computing the loss. Default: None + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
+        """
+        assert len(cls_scores) == 1
+        assert self.anchor_generator.num_levels == 1
+
+        device = cls_scores[0].device
+        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+        anchor_list, valid_flag_list = self.get_anchors(
+            featmap_sizes, img_metas, device=device)
+
+        # The output level is always 1
+        anchor_list = [anchors[0] for anchors in anchor_list]
+        valid_flag_list = [valid_flags[0] for valid_flags in valid_flag_list]
+
+        cls_scores_list = levels_to_images(cls_scores)
+        bbox_preds_list = levels_to_images(bbox_preds)
+
+        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+        cls_reg_targets = self.get_targets(
+            cls_scores_list,
+            bbox_preds_list,
+            anchor_list,
+            valid_flag_list,
+            gt_bboxes,
+            img_metas,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
+            gt_labels_list=gt_labels,
+            label_channels=label_channels)
+        if cls_reg_targets is None:
+            return None
+        (batch_labels, batch_label_weights, num_total_pos, num_total_neg,
+         batch_bbox_weights, batch_pos_predicted_boxes,
+         batch_target_boxes) = cls_reg_targets
+
+        flatten_labels = batch_labels.reshape(-1)
+        batch_label_weights = batch_label_weights.reshape(-1)
+        cls_score = cls_scores[0].permute(0, 2, 3,
+                                          1).reshape(-1, self.cls_out_channels)
+
+        num_total_samples = (num_total_pos +
+                             num_total_neg) if self.sampling else num_total_pos
+        num_total_samples = reduce_mean(
+            cls_score.new_tensor(num_total_samples)).clamp_(1.0).item()
+
+        # classification loss
+        loss_cls = self.loss_cls(
+            cls_score,
+            flatten_labels,
+            batch_label_weights,
+            avg_factor=num_total_samples)
+
+        # regression loss
+        if batch_pos_predicted_boxes.shape[0] == 0:
+            # no pos sample
+            loss_bbox = batch_pos_predicted_boxes.sum() * 0
+        else:
+            loss_bbox = self.loss_bbox(
+                batch_pos_predicted_boxes,
+                batch_target_boxes,
+                batch_bbox_weights.float(),
+                avg_factor=num_total_samples)
+
+        return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
+
+    def get_targets(self,
+                    cls_scores_list,
+                    bbox_preds_list,
+                    anchor_list,
+                    valid_flag_list,
+                    gt_bboxes_list,
+                    img_metas,
+                    gt_bboxes_ignore_list=None,
+                    gt_labels_list=None,
+                    label_channels=1,
+                    unmap_outputs=True):
+        """Compute regression and classification targets for anchors in
+        multiple images.
+
+        Args:
+            cls_scores_list (list[Tensor]): Classification scores of each
+                image, each a 2D tensor of shape
+                (h * w, num_anchors * num_classes).
+            bbox_preds_list (list[Tensor]): Bbox preds of each image,
+                each a 2D tensor of shape (h * w, num_anchors * 4).
+            anchor_list (list[Tensor]): Anchors of each image. Each element
+                is a tensor of shape (h * w * num_anchors, 4).
+            valid_flag_list (list[Tensor]): Valid flags of each image. Each
+                element is a tensor of shape (h * w * num_anchors, ).
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+            img_metas (list[dict]): Meta info of each image.
+            gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+                ignored.
+            gt_labels_list (list[Tensor]): Ground truth labels of each box.
+            label_channels (int): Channel of label.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors.
+
+        Returns:
+            tuple: Usually returns a tuple containing learning targets.
+
+                - batch_labels (Tensor): Labels of all images, a tensor \
+                    of shape (batch, h * w * num_anchors).
+                - batch_label_weights (Tensor): Label weights of all \
+                    images, a tensor of shape (batch, h * w * num_anchors).
+                - num_total_pos (int): Number of positive samples in all \
+                    images.
+                - num_total_neg (int): Number of negative samples in all \
+                    images.
+
+            additional_returns: This function enables user-defined returns
+                from `self._get_targets_single`. These returns are currently
+                refined to properties at each feature map (i.e. having HxW
+                dimension). The results will be concatenated at the end.
+        """
+        num_imgs = len(img_metas)
+        assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+        # compute targets for each image
+        if gt_bboxes_ignore_list is None:
+            gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+        if gt_labels_list is None:
+            gt_labels_list = [None for _ in range(num_imgs)]
+        results = multi_apply(
+            self._get_targets_single,
+            bbox_preds_list,
+            anchor_list,
+            valid_flag_list,
+            gt_bboxes_list,
+            gt_bboxes_ignore_list,
+            gt_labels_list,
+            img_metas,
+            label_channels=label_channels,
+            unmap_outputs=unmap_outputs)
+        (all_labels, all_label_weights, pos_inds_list, neg_inds_list,
+         sampling_results_list) = results[:5]
+        rest_results = list(results[5:])  # user-added return values
+        # no valid anchors
+        if any([labels is None for labels in all_labels]):
+            return None
+        # sampled anchors of all images
+        num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+        num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+
+        batch_labels = torch.stack(all_labels, 0)
+        batch_label_weights = torch.stack(all_label_weights, 0)
+
+        res = (batch_labels, batch_label_weights, num_total_pos,
+               num_total_neg)
+        for i, rests in enumerate(rest_results):  # user-added return values
+            rest_results[i] = torch.cat(rests, 0)
+
+        return res + tuple(rest_results)
+
+    def _get_targets_single(self,
+                            bbox_preds,
+                            flat_anchors,
+                            valid_flags,
+                            gt_bboxes,
+                            gt_bboxes_ignore,
+                            gt_labels,
+                            img_meta,
+                            label_channels=1,
+                            unmap_outputs=True):
+        """Compute regression and classification targets for anchors in a
+        single image.
+
+        Args:
+            bbox_preds (Tensor): Bbox predictions of the image, with
+                shape (h * w, 4).
+            flat_anchors (Tensor): Anchors of the image, with shape
+                (h * w * num_anchors, 4).
+            valid_flags (Tensor): Valid flags of the image, with shape
+                (h * w * num_anchors, ).
+            gt_bboxes (Tensor): Ground truth bboxes of the image,
+                shape (num_gts, 4).
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+                ignored, shape (num_ignored_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            img_meta (dict): Meta info of the image.
+            label_channels (int): Channel of label.
+            unmap_outputs (bool): Whether to map outputs back to the original
+                set of anchors.
+
+        Returns:
+            tuple:
+                labels (Tensor): Labels of the image, with shape
+                    (h * w * num_anchors, ).
+                label_weights (Tensor): Label weights of the image, with
+                    shape (h * w * num_anchors, ).
+                pos_inds (Tensor): Pos index of the image.
+                neg_inds (Tensor): Neg index of the image.
+                sampling_result (obj:`SamplingResult`): Sampling result.
+                pos_bbox_weights (Tensor): The weights used to calculate
+                    the bbox branch loss, with shape (num, ).
+                pos_predicted_boxes (Tensor): The predicted boxes used to
+                    calculate the bbox branch loss, with shape (num, 4).
+                pos_target_boxes (Tensor): The target boxes used to
+                    calculate the bbox branch loss, with shape (num, 4).
+        """
+        inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+                                           img_meta['img_shape'][:2],
+                                           self.train_cfg.allowed_border)
+        if not inside_flags.any():
+            return (None, ) * 8
+        # assign gt and sample anchors
+        anchors = flat_anchors[inside_flags, :]
+        bbox_preds = bbox_preds.reshape(-1, 4)
+        bbox_preds = bbox_preds[inside_flags, :]
+
+        # decoded bbox
+        decoder_bbox_preds = self.bbox_coder.decode(anchors, bbox_preds)
+        assign_result = self.assigner.assign(
+            decoder_bbox_preds, anchors, gt_bboxes, gt_bboxes_ignore,
+            None if self.sampling else gt_labels)
+
+        pos_bbox_weights = assign_result.get_extra_property('pos_idx')
+        pos_predicted_boxes = assign_result.get_extra_property(
+            'pos_predicted_boxes')
+        pos_target_boxes = assign_result.get_extra_property('target_boxes')
+
+        sampling_result = self.sampler.sample(assign_result, anchors,
+                                              gt_bboxes)
+        num_valid_anchors = anchors.shape[0]
+        labels = anchors.new_full((num_valid_anchors, ),
+                                  self.num_classes,
+                                  dtype=torch.long)
+        label_weights = anchors.new_zeros(
+            num_valid_anchors, dtype=torch.float)
+
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+        if len(pos_inds) > 0:
+            if gt_labels is None:
+                # Only rpn gives gt_labels as None
+                # Foreground is the first class since v2.5.0
+                labels[pos_inds] = 0
+            else:
+                labels[pos_inds] = gt_labels[
+                    sampling_result.pos_assigned_gt_inds]
+            if self.train_cfg.pos_weight <= 0:
+                label_weights[pos_inds] = 1.0
+            else:
+                label_weights[pos_inds] = self.train_cfg.pos_weight
+        if len(neg_inds) > 0:
+            label_weights[neg_inds] = 1.0
+
+        # map up to original set of anchors
+        if unmap_outputs:
+            num_total_anchors = flat_anchors.size(0)
+            labels = unmap(
+                labels, num_total_anchors, inside_flags,
+                fill=self.num_classes)  # fill bg label
+            label_weights = unmap(label_weights, num_total_anchors,
+                                  inside_flags)
+
+        return (labels, label_weights, pos_inds, neg_inds, sampling_result,
+                pos_bbox_weights, pos_predicted_boxes, pos_target_boxes)
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
index 12032dda2..907b741ce 100644
--- a/mmdet/models/detectors/__init__.py
+++ b/mmdet/models/detectors/__init__.py
@@ -29,6 +29,7 @@ from .two_stage import TwoStageDetector
 from .vfnet import VFNet
 from .yolact import YOLACT
 from .yolo import YOLOV3
+from .yolof import YOLOF

 __all__ = [
     'ATSS', 'BaseDetector', 'SingleStageDetector',
@@ -37,5 +38,5 @@ __all__ = [
     'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
     'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
     'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN',
-    'SCNet', 'AutoAssign'
+    'SCNet', 'YOLOF', 'AutoAssign'
 ]
diff --git a/mmdet/models/detectors/yolof.py b/mmdet/models/detectors/yolof.py
new file mode 100644
index 000000000..dc7b3adfe
--- /dev/null
+++ b/mmdet/models/detectors/yolof.py
@@ -0,0 +1,18 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class YOLOF(SingleStageDetector):
+    r"""Implementation of `You Only Look One-level Feature
+    <https://arxiv.org/abs/2103.09460>`_"""
+
+    def __init__(self,
+                 backbone,
+                 neck,
+                 bbox_head,
+                 train_cfg=None,
+                 test_cfg=None,
+                 pretrained=None):
+        super(YOLOF, self).__init__(backbone, neck, bbox_head, train_cfg,
+                                    test_cfg, pretrained)
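The `normalized_cls_score` formula in `YOLOFHead.forward_single` above is the log-space form of multiplying two sigmoid probabilities; the `torch.clamp(..., max=INF)` calls only guard against overflow. A quick sanity check of that identity (ours, not part of the patch):

```python
import torch

# sigmoid(a + b - log(1 + e^a + e^b)) == sigmoid(a) * sigmoid(b),
# i.e. the implicit objectness is fused into the class score without
# ever leaving logit space.
cls_logit = torch.randn(5)
obj_logit = torch.randn(5)
fused = cls_logit + obj_logit - torch.log(
    1. + cls_logit.exp() + obj_logit.exp())
assert torch.allclose(
    fused.sigmoid(), cls_logit.sigmoid() * obj_logit.sigmoid(), atol=1e-6)
```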
diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py
index 02f833a8a..5bf955802 100644
--- a/mmdet/models/necks/__init__.py
+++ b/mmdet/models/necks/__init__.py
@@ -1,5 +1,6 @@
 from .bfp import BFP
 from .channel_mapper import ChannelMapper
+from .dilated_encoder import DilatedEncoder
 from .fpg import FPG
 from .fpn import FPN
 from .fpn_carafe import FPN_CARAFE
@@ -12,5 +13,5 @@ from .yolo_neck import YOLOV3Neck

 __all__ = [
     'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
-    'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG'
+    'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG', 'DilatedEncoder'
 ]
diff --git a/mmdet/models/necks/dilated_encoder.py b/mmdet/models/necks/dilated_encoder.py
new file mode 100644
index 000000000..e97d5ccc7
--- /dev/null
+++ b/mmdet/models/necks/dilated_encoder.py
@@ -0,0 +1,107 @@
+import torch.nn as nn
+from mmcv.cnn import (ConvModule, caffe2_xavier_init, constant_init, is_norm,
+                      normal_init)
+from torch.nn import BatchNorm2d
+
+from ..builder import NECKS
+
+
+class Bottleneck(nn.Module):
+    """Bottleneck block for DilatedEncoder used in `YOLOF
+    <https://arxiv.org/abs/2103.09460>`_.
+
+    The Bottleneck contains three ConvLayers and one residual connection.
+
+    Args:
+        in_channels (int): The number of input channels.
+        mid_channels (int): The number of middle output channels.
+        dilation (int): Dilation rate.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 mid_channels,
+                 dilation,
+                 norm_cfg=dict(type='BN', requires_grad=True)):
+        super(Bottleneck, self).__init__()
+        self.conv1 = ConvModule(
+            in_channels, mid_channels, 1, norm_cfg=norm_cfg)
+        self.conv2 = ConvModule(
+            mid_channels,
+            mid_channels,
+            3,
+            padding=dilation,
+            dilation=dilation,
+            norm_cfg=norm_cfg)
+        self.conv3 = ConvModule(
+            mid_channels, in_channels, 1, norm_cfg=norm_cfg)
+
+    def forward(self, x):
+        identity = x
+        out = self.conv1(x)
+        out = self.conv2(out)
+        out = self.conv3(out)
+        out = out + identity
+        return out
+
+
+@NECKS.register_module()
+class DilatedEncoder(nn.Module):
+    """Dilated Encoder for `YOLOF <https://arxiv.org/abs/2103.09460>`_.
+
+    This module contains two types of components:
+        - the original FPN lateral convolution layer and fpn convolution
+          layer, which are 1x1 conv + 3x3 conv
+        - the dilated residual block
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        block_mid_channels (int): The number of middle block output channels.
+        num_residual_blocks (int): The number of residual blocks.
+ """ + + def __init__(self, in_channels, out_channels, block_mid_channels, + num_residual_blocks): + super(DilatedEncoder, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.block_mid_channels = block_mid_channels + self.num_residual_blocks = num_residual_blocks + self.block_dilations = [2, 4, 6, 8] + self._init_layers() + + def _init_layers(self): + self.lateral_conv = nn.Conv2d( + self.in_channels, self.out_channels, kernel_size=1) + self.lateral_norm = BatchNorm2d(self.out_channels) + self.fpn_conv = nn.Conv2d( + self.out_channels, self.out_channels, kernel_size=3, padding=1) + self.fpn_norm = BatchNorm2d(self.out_channels) + encoder_blocks = [] + for i in range(self.num_residual_blocks): + dilation = self.block_dilations[i] + encoder_blocks.append( + Bottleneck( + self.out_channels, + self.block_mid_channels, + dilation=dilation)) + self.dilated_encoder_blocks = nn.Sequential(*encoder_blocks) + + def init_weights(self): + caffe2_xavier_init(self.lateral_conv) + caffe2_xavier_init(self.fpn_conv) + for m in [self.lateral_norm, self.fpn_norm]: + constant_init(m, 1) + for m in self.dilated_encoder_blocks.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, std=0.01) + if is_norm(m): + constant_init(m, 1) + + def forward(self, feature): + out = self.lateral_norm(self.lateral_conv(feature[-1])) + out = self.fpn_norm(self.fpn_conv(out)) + return self.dilated_encoder_blocks(out), diff --git a/tests/test_data/test_pipelines/test_transform/test_transform.py b/tests/test_data/test_pipelines/test_transform/test_transform.py index 85ddbbc72..b69d5ef9f 100644 --- a/tests/test_data/test_pipelines/test_transform/test_transform.py +++ b/tests/test_data/test_pipelines/test_transform/test_transform.py @@ -750,3 +750,43 @@ def test_cutout(): cutout_module = build_from_cfg(transform, PIPELINES) cutout_result = cutout_module(copy.deepcopy(results)) assert cutout_result['img'].sum() > img.sum() + + +def test_random_shift(): + # test assertion for invalid shift_ratio + with pytest.raises(AssertionError): + transform = dict(type='RandomShift', shift_ratio=1.5) + build_from_cfg(transform, PIPELINES) + + # test assertion for invalid max_shift_px + with pytest.raises(AssertionError): + transform = dict(type='RandomShift', max_shift_px=-1) + build_from_cfg(transform, PIPELINES) + + results = dict() + img = mmcv.imread( + osp.join(osp.dirname(__file__), '../../../data/color.jpg'), 'color') + results['img'] = img + # TODO: add img_fields test + results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore'] + + def create_random_bboxes(num_bboxes, img_w, img_h): + bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2)) + bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2)) + bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1) + bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype( + np.int) + return bboxes + + h, w, _ = img.shape + gt_bboxes = create_random_bboxes(8, w, h) + gt_bboxes_ignore = create_random_bboxes(2, w, h) + results['gt_labels'] = torch.ones(gt_bboxes.shape[0]) + results['gt_bboxes'] = gt_bboxes + results['gt_bboxes_ignore'] = gt_bboxes_ignore + transform = dict(type='RandomShift', shift_ratio=1.0) + random_shift_module = build_from_cfg(transform, PIPELINES) + results = random_shift_module(results) + + assert results['img'].shape[:2] == (h, w) + assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0] diff --git a/tests/test_models/test_dense_heads/test_yolof_head.py 
b/tests/test_models/test_dense_heads/test_yolof_head.py new file mode 100644 index 000000000..ef21b66cf --- /dev/null +++ b/tests/test_models/test_dense_heads/test_yolof_head.py @@ -0,0 +1,75 @@ +import mmcv +import torch + +from mmdet.models.dense_heads import YOLOFHead + + +def test_yolof_head_loss(): + """Tests yolof head loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + train_cfg = mmcv.Config( + dict( + assigner=dict( + type='UniformAssigner', + pos_ignore_thr=0.15, + neg_ignore_thr=0.7), + allowed_border=-1, + pos_weight=-1, + debug=False)) + self = YOLOFHead( + num_classes=4, + in_channels=1, + reg_decoded_bbox=True, + train_cfg=train_cfg, + anchor_generator=dict( + type='AnchorGenerator', + ratios=[1.0], + scales=[1, 2, 4, 8, 16], + strides=[32]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1., 1., 1., 1.], + add_ctr_clamp=True, + ctr_clamp=32), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=1.0)) + feat = [torch.rand(1, 1, s // 32, s // 32)] + cls_scores, bbox_preds = self.forward(feat) + + # Test that empty ground truth encourages the network to predict background + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + gt_bboxes_ignore = None + empty_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + empty_cls_loss = empty_gt_losses['loss_cls'] + empty_box_loss = empty_gt_losses['loss_bbox'] + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_box_loss.item() == 0, ( + 'there should be no box loss when there are no true boxes') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + one_gt_losses = self.loss(cls_scores, bbox_preds, gt_bboxes, gt_labels, + img_metas, gt_bboxes_ignore) + onegt_cls_loss = one_gt_losses['loss_cls'] + onegt_box_loss = one_gt_losses['loss_bbox'] + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_box_loss.item() > 0, 'box loss should be non-zero' diff --git a/tests/test_models/test_necks.py b/tests/test_models/test_necks.py index 56885477f..312c5917f 100644 --- a/tests/test_models/test_necks.py +++ b/tests/test_models/test_necks.py @@ -2,7 +2,7 @@ import pytest import torch from torch.nn.modules.batchnorm import _BatchNorm -from mmdet.models.necks import FPN, ChannelMapper +from mmdet.models.necks import FPN, ChannelMapper, DilatedEncoder def test_fpn(): @@ -236,3 +236,13 @@ def test_channel_mapper(): for i in range(len(feats)): outs[i].shape[1] == out_channels outs[i].shape[2] == outs[i].shape[3] == s // (2**i) + + +def test_dilated_encoder(): + in_channels = 16 + out_channels = 32 + out_shape = 34 + dilated_encoder = DilatedEncoder(in_channels, out_channels, 16, 2) + feat = [torch.rand(1, in_channels, 34, 34)] + out_feat = dilated_encoder(feat)[0] + assert out_feat.shape == (1, out_channels, out_shape, out_shape) diff --git a/tests/test_utils/test_assigner.py b/tests/test_utils/test_assigner.py index 7971662dc..949234b6f 100644 --- a/tests/test_utils/test_assigner.py +++ b/tests/test_utils/test_assigner.py @@ -8,7 +8,8 @@ 
import torch from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner, CenterRegionAssigner, HungarianAssigner, - MaxIoUAssigner, PointAssigner) + MaxIoUAssigner, PointAssigner, + UniformAssigner) def test_max_iou_assigner(): @@ -422,3 +423,75 @@ def test_hungarian_match_assigner(): assert torch.all(assign_result.gt_inds > -1) assert (assign_result.gt_inds > 0).sum() == gt_bboxes.size(0) assert (assign_result.labels > -1).sum() == gt_bboxes.size(0) + + +def test_uniform_assigner(): + self = UniformAssigner(0.15, 0.7, 1) + pred_bbox = torch.FloatTensor([ + [1, 1, 12, 8], + [4, 4, 20, 20], + [1, 5, 15, 15], + [30, 5, 32, 42], + ]) + anchor = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_labels = torch.LongTensor([2, 3]) + assign_result = self.assign( + pred_bbox, anchor, gt_bboxes, gt_labels=gt_labels) + assert len(assign_result.gt_inds) == 4 + assert len(assign_result.labels) == 4 + + expected_gt_inds = torch.LongTensor([-1, 0, 2, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_uniform_assigner_with_empty_gt(): + """Test corner case where an image might have no true detections.""" + self = UniformAssigner(0.15, 0.7, 1) + pred_bbox = torch.FloatTensor([ + [1, 1, 12, 8], + [4, 4, 20, 20], + [1, 5, 15, 15], + [30, 5, 32, 42], + ]) + anchor = torch.FloatTensor([ + [0, 0, 10, 10], + [10, 10, 20, 20], + [5, 5, 15, 15], + [32, 32, 38, 42], + ]) + gt_bboxes = torch.empty(0, 4) + assign_result = self.assign(pred_bbox, anchor, gt_bboxes) + + expected_gt_inds = torch.LongTensor([0, 0, 0, 0]) + assert torch.all(assign_result.gt_inds == expected_gt_inds) + + +def test_uniform_assigner_with_empty_boxes(): + """Test corner case where a network might predict no boxes.""" + self = UniformAssigner(0.15, 0.7, 1) + pred_bbox = torch.empty((0, 4)) + anchor = torch.empty((0, 4)) + gt_bboxes = torch.FloatTensor([ + [0, 0, 10, 9], + [0, 10, 10, 19], + ]) + gt_labels = torch.LongTensor([2, 3]) + + # Test with gt_labels + assign_result = self.assign( + pred_bbox, anchor, gt_bboxes, gt_labels=gt_labels) + assert len(assign_result.gt_inds) == 0 + assert tuple(assign_result.labels.shape) == (0, ) + + # Test without gt_labels + assign_result = self.assign(pred_bbox, anchor, gt_bboxes, gt_labels=None) + assert len(assign_result.gt_inds) == 0 diff --git a/tests/test_utils/test_coder.py b/tests/test_utils/test_coder.py index c0bfc0f7d..2dca41319 100644 --- a/tests/test_utils/test_coder.py +++ b/tests/test_utils/test_coder.py @@ -58,6 +58,21 @@ def test_delta_bbox_coder(): out = coder.decode(rois, deltas, max_shape=(32, 32)) assert rois.shape == out.shape + # test add_ctr_clamp + coder = DeltaXYWHBBoxCoder(add_ctr_clamp=True, ctr_clamp=2) + + rois = torch.Tensor([[0., 0., 6., 6.], [0., 0., 1., 1.], [0., 0., 1., 1.], + [5., 5., 5., 5.]]) + deltas = torch.Tensor([[1., 1., 2., 2.], [1., 1., 1., 1.], + [0., 0., 2., -1.], [0.7, -1.9, -0.5, 0.3]]) + expected_decode_bboxes = torch.Tensor([[0.0000, 0.0000, 27.1672, 27.1672], + [0.1409, 0.1409, 2.8591, 2.8591], + [0.0000, 0.3161, 4.1945, 0.6839], + [5.0000, 5.0000, 5.0000, 5.0000]]) + + out = coder.decode(rois, deltas, max_shape=(32, 32)) + assert expected_decode_bboxes.allclose(out, atol=1e-04) + def test_tblr_bbox_coder(): coder = TBLRBBoxCoder(normalizer=15.)
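As a final illustration of the new `add_ctr_clamp` behaviour exercised by the test above, here is a minimal decode, assuming the same import path the tests use; the expected output was worked out by hand from `delta2bbox` (a sketch, not part of the patch):

```python
import torch
from mmdet.core.bbox.coder import DeltaXYWHBBoxCoder

# A 6x6 roi centered at (3, 3) with a huge predicted center offset
# (dx = dy = 10 means a 60 px shift) and no size change.
coder = DeltaXYWHBBoxCoder(add_ctr_clamp=True, ctr_clamp=2)
rois = torch.Tensor([[0., 0., 6., 6.]])
deltas = torch.Tensor([[10., 10., 0., 0.]])
out = coder.decode(rois, deltas, max_shape=(32, 32))
# The center shift pw*dx / ph*dy is clamped to ctr_clamp=2 px per axis,
# so the box only moves from (0, 0, 6, 6) to (2, 2, 8, 8).
assert torch.allclose(out, torch.Tensor([[2., 2., 8., 8.]]))
```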