From 2294badd86b0bc83e49692187a4639b69f2ec4b8 Mon Sep 17 00:00:00 2001 From: BigDong Date: Tue, 28 Sep 2021 16:39:18 +0800 Subject: [PATCH] [Feature]SOLO: Segmenting Objects by Locations (#5832) * add SOLO * add decoupled SOLO * update decoupled SOLO * fix linting errors * format config filename, config content, loss names, norm_cfg * fix linting errors * fix matrix_nms and configs * Add unit tests for SOLO head * add diceloss * support mmdet-v2+ * add decopledhead * clean Chinese comments * update SOLO * fix * delet debug files * update solo config * fix bug * [Fix]: fix some params cannot get grad * [fix] make sure params can get grad * init commit for resutls * add results and instance results * add docstr * add more unitets * add more unitets * add more unitets * add more unintest * add unitet for instance results * add example * add meta_info_keys results_keys * add modified from * fix unitets * fix typo * add instance seg releated base * forward train for solo * fix simpletest * add docstr * convert to tensor at begin * refactor yolact traing * refactor yolact test * fix test of yolact * fix empty det of yolact * fix return tuple * add format_results * add testfor formatr * solo * add unitest for format_results * add unitest * solo * remove yolact relatede modification * fix zero bbox * fix score size * fix desolo head * update solo head * fix error * rename some attribute * rename some attribute * rename decouple * add doc * format loss * reconer decople * add doc * fix test * fix test * fix doc * remove points nms * refactor the post process * refactor post process of decaouple * refactor base * refactor get_target single * refactor the training of decouple * refactor test of decouple * refactor dice loss * refactor dice * change to format a dict * support detection results in test.py * add base one-stage segmentor * fix doc * add onnx export * add solo config * add dice loss test unit * add solo_head test unit * add more detailed comments * resolve commnets * add test unit * update docstrings and move center of mass to core.utils * add center of mass test unit * resolve comments * resolve commets * fix rle encode * fix results * fix results * abstract dice loss * update docstring * add EPS * add center of mass test unit * add eps parameter * add vis * add nms test unit * configs/ add configs * add desolo light config file * support desolo light head * add desolo light config * add matrix_nms test unit * fix matrix_nms test unit * update matrix doc string * fix error * fix logic error * fix logic error * add comment in test unit * move has_acted to initialization * update solo readme * rename * revert test * fix import in example * fix unitest * add more uintest * add more unites * add more unitest * rename meta to meta_info * fix docstr * fix foc * fix doc * add format_results * fix format results * fix some default value and function name * fix desolo light head error * fix doc and move isntancedata to a new file * fix typo * fix unitest in torch 13 * update matrix nms docstring * fix hard code * add vis * add vis * fix lint * fix doc * fix doc * fix vis * fix vis * fix vis * fix forwardummy doc * fix doc * fix comment * fix doc * fix order of argument * add base one-stage segmentor * fix config files * fix doc * fix doc * support solo * fix error * support solo * rename cls_score * support solo * update model zoo * update docstring * update docstring Co-authored-by: WXinlong Co-authored-by: zhangshilong <2392587229zsl@gmail.com> --- configs/solo/README.md | 42 + 
.../decoupled_solo_light_r50_fpn_3x_coco.py | 63 + .../solo/decoupled_solo_r50_fpn_1x_coco.py | 28 + .../solo/decoupled_solo_r50_fpn_3x_coco.py | 25 + configs/solo/metafile.yml | 115 ++ configs/solo/solo_r50_fpn_1x_coco.py | 53 + configs/solo/solo_r50_fpn_3x_coco.py | 28 + docs/model_zoo.md | 8 + mmdet/core/post_processing/__init__.py | 3 +- mmdet/core/post_processing/matrix_nms.py | 121 ++ mmdet/core/utils/__init__.py | 6 +- mmdet/core/utils/misc.py | 43 + mmdet/models/dense_heads/__init__.py | 4 +- mmdet/models/dense_heads/solo_head.py | 1177 +++++++++++++++++ mmdet/models/detectors/__init__.py | 3 +- mmdet/models/detectors/solo.py | 29 + mmdet/models/losses/__init__.py | 3 +- mmdet/models/losses/dice_loss.py | 123 ++ model-index.yml | 2 + .../test_dense_heads/test_solo_head.py | 284 ++++ tests/test_models/test_loss.py | 53 +- tests/test_utils/test_misc.py | 19 +- tests/test_utils/test_nms.py | 75 ++ 23 files changed, 2299 insertions(+), 8 deletions(-) create mode 100644 configs/solo/README.md create mode 100644 configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py create mode 100644 configs/solo/decoupled_solo_r50_fpn_1x_coco.py create mode 100644 configs/solo/decoupled_solo_r50_fpn_3x_coco.py create mode 100644 configs/solo/metafile.yml create mode 100644 configs/solo/solo_r50_fpn_1x_coco.py create mode 100644 configs/solo/solo_r50_fpn_3x_coco.py create mode 100644 mmdet/core/post_processing/matrix_nms.py create mode 100644 mmdet/models/dense_heads/solo_head.py create mode 100644 mmdet/models/detectors/solo.py create mode 100644 mmdet/models/losses/dice_loss.py create mode 100644 tests/test_models/test_dense_heads/test_solo_head.py create mode 100644 tests/test_utils/test_nms.py diff --git a/configs/solo/README.md b/configs/solo/README.md new file mode 100644 index 000000000..709e246f6 --- /dev/null +++ b/configs/solo/README.md @@ -0,0 +1,42 @@ +# SOLO: Segmenting Objects by Locations + +## Introduction + +``` +@inproceedings{wang2020solo, + title = {{SOLO}: Segmenting Objects by Locations}, + author = {Wang, Xinlong and Kong, Tao and Shen, Chunhua and Jiang, Yuning and Li, Lei}, + booktitle = {Proc. Eur. Conf. 
Computer Vision (ECCV)},
+  year = {2020}
+}
+```
+
+## Results and Models
+
+### SOLO
+
+| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
+|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:------:|:--------:|
+| R-50 | pytorch | N | 1x | 8.0 | 14.0 | 33.1 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055-2290a6b8.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055.log.json) |
+| R-50 | pytorch | Y | 3x | 7.4 | 14.0 | 35.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353-11d224d7.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353.log.json) |
+
+### Decoupled SOLO
+
+| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
+|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:-------:|:--------:|
+| R-50 | pytorch | N | 1x | 7.8 | 12.5 | 33.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348-6337c589.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348.log.json) |
+| R-50 | pytorch | Y | 3x | 7.9 | 12.5 | 36.7 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504-7b3301ec.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504.log.json) |
+
+- Decoupled SOLO has a decoupled head, which differs from the SOLO head.
+Decoupled SOLO is an efficient variant of SOLO with equivalent accuracy.
+Please refer to the corresponding config files for details.
+
+### Decoupled Light SOLO
+
+| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
+|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:------:|:--------:|
+| R-50 | pytorch | Y | 3x | 2.2 | 31.2 | 32.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703-e70e226f.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703.log.json) |
+
+- Decoupled Light SOLO uses a decoupled structure similar to the Decoupled
+SOLO head, with a lightweight head and a smaller input size. Please refer
+to the corresponding config files for details.
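+
+## Usage
+
+SOLO is trained and evaluated with the standard MMDetection tools; the
+commands below are a typical workflow (adjust the GPU count and paths to
+your setup):
+
+```shell
+# train SOLO on COCO with 8 GPUs
+./tools/dist_train.sh configs/solo/solo_r50_fpn_1x_coco.py 8
+
+# evaluate mask AP with a trained checkpoint
+./tools/dist_test.sh configs/solo/solo_r50_fpn_1x_coco.py \
+    ${CHECKPOINT_FILE} 8 --eval segm
+```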
diff --git a/configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py b/configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py new file mode 100644 index 000000000..101f8f1d3 --- /dev/null +++ b/configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py @@ -0,0 +1,63 @@ +_base_ = './decoupled_solo_r50_fpn_3x_coco.py' + +# model settings +model = dict( + mask_head=dict( + type='DecoupledSOLOLightHead', + num_classes=80, + in_channels=256, + stacked_convs=4, + feat_channels=256, + strides=[8, 8, 16, 32, 32], + scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)), + pos_scale=0.2, + num_grids=[40, 36, 24, 16, 12], + cls_down_index=0, + loss_mask=dict( + type='DiceLoss', use_sigmoid=True, activate=False, + loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='Resize', + img_scale=[(852, 512), (852, 480), (852, 448), (852, 416), (852, 384), + (852, 352)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(852, 512), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/configs/solo/decoupled_solo_r50_fpn_1x_coco.py b/configs/solo/decoupled_solo_r50_fpn_1x_coco.py new file mode 100644 index 000000000..b611cdf4d --- /dev/null +++ b/configs/solo/decoupled_solo_r50_fpn_1x_coco.py @@ -0,0 +1,28 @@ +_base_ = [ + './solo_r50_fpn_1x_coco.py', +] +# model settings +model = dict( + mask_head=dict( + type='DecoupledSOLOHead', + num_classes=80, + in_channels=256, + stacked_convs=7, + feat_channels=256, + strides=[8, 8, 16, 32, 32], + scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)), + pos_scale=0.2, + num_grids=[40, 36, 24, 16, 12], + cls_down_index=0, + loss_mask=dict( + type='DiceLoss', use_sigmoid=True, activate=False, + loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))) + +optimizer = dict(type='SGD', lr=0.01) diff --git a/configs/solo/decoupled_solo_r50_fpn_3x_coco.py b/configs/solo/decoupled_solo_r50_fpn_3x_coco.py new file mode 100644 index 000000000..4a8c19dec --- /dev/null +++ b/configs/solo/decoupled_solo_r50_fpn_3x_coco.py @@ -0,0 +1,25 @@ +_base_ = './solo_r50_fpn_3x_coco.py' + +# model settings +model = dict( + mask_head=dict( + type='DecoupledSOLOHead', + num_classes=80, + in_channels=256, + stacked_convs=7, + feat_channels=256, + strides=[8, 8, 16, 32, 32], + scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)), + pos_scale=0.2, + num_grids=[40, 36, 24, 16, 12], + cls_down_index=0, + 
loss_mask=dict( + type='DiceLoss', use_sigmoid=True, activate=False, + loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True))) diff --git a/configs/solo/metafile.yml b/configs/solo/metafile.yml new file mode 100644 index 000000000..b6244e80f --- /dev/null +++ b/configs/solo/metafile.yml @@ -0,0 +1,115 @@ +Collections: + - Name: SOLO + Metadata: + Training Data: COCO + Training Techniques: + - SGD with Momentum + - Weight Decay + Training Resources: 8x V100 GPUs + Architecture: + - FPN + - Convolution + - ResNet + Paper: https://arxiv.org/abs/1912.04488 + README: configs/solo/README.md + +Models: + - Name: decoupled_solo_r50_fpn_1x_coco + In Collection: SOLO + Config: configs/solo/decoupled_solo_r50_fpn_1x_coco.py + Metadata: + Training Memory (GB): 7.8 + Epochs: 12 + inference time (ms/im): + - value: 116.4 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (1333, 800) + Results: + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 33.9 + Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348-6337c589.pth + + - Name: decoupled_solo_r50_fpn_3x_coco + In Collection: SOLO + Config: configs/solo/decoupled_solo_r50_fpn_3x_coco.py + Metadata: + Training Memory (GB): 7.9 + Epochs: 36 + inference time (ms/im): + - value: 117.2 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (1333, 800) + Results: + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 36.7 + Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504-7b3301ec.pth + + - Name: decoupled_solo_light_r50_fpn_3x_coco + In Collection: SOLO + Config: configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py + Metadata: + Training Memory (GB): 2.2 + Epochs: 36 + inference time (ms/im): + - value: 35.0 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (852, 512) + Results: + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 32.9 + Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703-e70e226f.pth + + - Name: solo_r50_fpn_3x_coco + In Collection: SOLO + Config: configs/solo/solo_r50_fpn_3x_coco.py + Metadata: + Training Memory (GB): 7.4 + Epochs: 36 + inference time (ms/im): + - value: 94.2 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (1333, 800) + Results: + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 35.9 + Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353-11d224d7.pth + + - Name: solo_r50_fpn_1x_coco + In Collection: SOLO + Config: configs/solo/solo_r50_fpn_1x_coco.py + Metadata: + Training Memory (GB): 8.0 + Epochs: 12 + inference time (ms/im): + - value: 95.1 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (1333, 800) + Results: + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 33.1 + Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055-2290a6b8.pth diff --git a/configs/solo/solo_r50_fpn_1x_coco.py b/configs/solo/solo_r50_fpn_1x_coco.py new file mode 100644 index 
000000000..9093a5048 --- /dev/null +++ b/configs/solo/solo_r50_fpn_1x_coco.py @@ -0,0 +1,53 @@ +_base_ = [ + '../_base_/datasets/coco_instance.py', + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' +] + +# model settings +model = dict( + type='SOLO', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), + style='pytorch'), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=0, + num_outs=5), + mask_head=dict( + type='SOLOHead', + num_classes=80, + in_channels=256, + stacked_convs=7, + feat_channels=256, + strides=[8, 8, 16, 32, 32], + scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)), + pos_scale=0.2, + num_grids=[40, 36, 24, 16, 12], + cls_down_index=0, + loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)), + # model training and testing settings + test_cfg=dict( + nms_pre=500, + score_thr=0.1, + mask_thr=0.5, + filter_thr=0.05, + kernel='gaussian', # gaussian/linear + sigma=2.0, + max_per_img=100)) + +# optimizer +optimizer = dict(type='SGD', lr=0.01) diff --git a/configs/solo/solo_r50_fpn_3x_coco.py b/configs/solo/solo_r50_fpn_3x_coco.py new file mode 100644 index 000000000..52302cdf9 --- /dev/null +++ b/configs/solo/solo_r50_fpn_3x_coco.py @@ -0,0 +1,28 @@ +_base_ = './solo_r50_fpn_1x_coco.py' + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='Resize', + img_scale=[(1333, 800), (1333, 768), (1333, 736), (1333, 704), + (1333, 672), (1333, 640)], + multiscale_mode='value', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), +] +data = dict(train=dict(pipeline=train_pipeline)) + +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[27, 33]) +runner = dict(type='EpochBasedRunner', max_epochs=36) diff --git a/docs/model_zoo.md b/docs/model_zoo.md index ea9fd0e27..de3ca8c81 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -230,6 +230,14 @@ Please refer to [CenterNet](https://github.com/open-mmlab/mmdetection/blob/maste Please refer to [YOLOX](https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox) for details. +### PVT + +Please refer to [PVT](https://github.com/open-mmlab/mmdetection/blob/master/configs/pvt) for details. + +### SOLO + +Please refer to [SOLO](https://github.com/open-mmlab/mmdetection/blob/master/configs/solo) for details. + ### Other datasets We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face). 
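As a quick sanity check of the new configs, inference with one of the released checkpoints follows MMDetection's usual high-level API (a sketch; file paths are illustrative):

```python
from mmdet.apis import inference_detector, init_detector

# Any config/checkpoint pair from the SOLO model zoo above works here.
config_file = 'configs/solo/solo_r50_fpn_1x_coco.py'
checkpoint_file = 'solo_r50_fpn_1x_coco_20210821_035055-2290a6b8.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')
# draw the predicted instance masks and save the visualization
model.show_result('demo/demo.jpg', result, out_file='solo_demo.jpg')
```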
diff --git a/mmdet/core/post_processing/__init__.py b/mmdet/core/post_processing/__init__.py
index bcb63497f..00376bd49 100644
--- a/mmdet/core/post_processing/__init__.py
+++ b/mmdet/core/post_processing/__init__.py
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .bbox_nms import fast_nms, multiclass_nms
+from .matrix_nms import mask_matrix_nms
 from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
                          merge_aug_proposals, merge_aug_scores)

 __all__ = [
     'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
-    'merge_aug_scores', 'merge_aug_masks', 'fast_nms'
+    'merge_aug_scores', 'merge_aug_masks', 'mask_matrix_nms', 'fast_nms'
 ]
diff --git a/mmdet/core/post_processing/matrix_nms.py b/mmdet/core/post_processing/matrix_nms.py
new file mode 100644
index 000000000..e2fc5d9e2
--- /dev/null
+++ b/mmdet/core/post_processing/matrix_nms.py
@@ -0,0 +1,121 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+def mask_matrix_nms(masks,
+                    labels,
+                    scores,
+                    filter_thr=-1,
+                    nms_pre=-1,
+                    max_num=-1,
+                    kernel='gaussian',
+                    sigma=2.0,
+                    mask_area=None):
+    """Matrix NMS for multi-class masks.
+
+    Args:
+        masks (Tensor): Has shape (num_instances, h, w).
+        labels (Tensor): Labels of corresponding masks,
+            has shape (num_instances,).
+        scores (Tensor): Mask scores of corresponding masks,
+            has shape (num_instances,).
+        filter_thr (float): Score threshold to filter the masks
+            after matrix nms. Default: -1, which means do not
+            use filter_thr.
+        nms_pre (int): The max number of instances to do the matrix nms.
+            Default: -1, which means do not use nms_pre.
+        max_num (int, optional): If there are more than max_num masks after
+            matrix nms, only the top max_num will be kept. Default: -1,
+            which means do not use max_num.
+        kernel (str): 'linear' or 'gaussian'.
+        sigma (float): std in the gaussian method.
+        mask_area (Tensor): The sum of seg_masks.
+
+    Returns:
+        tuple(Tensor): Processed mask results.
+
+            - scores (Tensor): Updated scores, has shape (n,).
+            - labels (Tensor): Remained labels, has shape (n,).
+            - masks (Tensor): Remained masks, has shape (n, h, w).
+            - keep_inds (Tensor): The indices of the remaining masks
+              in the input masks, has shape (n,).
+    """
+    assert len(labels) == len(masks) == len(scores)
+    if len(labels) == 0:
+        return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
+            0, *masks.shape[-2:]), labels.new_zeros(0)
+    if mask_area is None:
+        mask_area = masks.sum((1, 2)).float()
+    else:
+        assert len(masks) == len(mask_area)
+
+    # sort and keep top nms_pre
+    scores, sort_inds = torch.sort(scores, descending=True)
+
+    keep_inds = sort_inds
+    if nms_pre > 0 and len(sort_inds) > nms_pre:
+        sort_inds = sort_inds[:nms_pre]
+        keep_inds = keep_inds[:nms_pre]
+        scores = scores[:nms_pre]
+    masks = masks[sort_inds]
+    mask_area = mask_area[sort_inds]
+    labels = labels[sort_inds]
+
+    num_masks = len(labels)
+    flatten_masks = masks.reshape(num_masks, -1).float()
+    # intersection.
+    inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
+    expanded_mask_area = mask_area.expand(num_masks, num_masks)
+    # Upper triangle iou matrix.
+    iou_matrix = (inter_matrix /
+                  (expanded_mask_area + expanded_mask_area.transpose(1, 0) -
+                   inter_matrix)).triu(diagonal=1)
+    # label_specific matrix.
+    expanded_labels = labels.expand(num_masks, num_masks)
+    # Upper triangle label matrix.
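+    # (i, j) is True only when masks i and j share the same class label:
+    # Matrix NMS decays a score only by higher-scoring masks of the same
+    # category, mirroring class-wise hard NMS.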
+    label_matrix = (expanded_labels == expanded_labels.transpose(
+        1, 0)).triu(diagonal=1)
+
+    # IoU compensation
+    compensate_iou, _ = (iou_matrix * label_matrix).max(0)
+    compensate_iou = compensate_iou.expand(num_masks,
+                                           num_masks).transpose(1, 0)
+
+    # IoU decay
+    decay_iou = iou_matrix * label_matrix
+
+    # Calculate the decay_coefficient
+    if kernel == 'gaussian':
+        decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
+        compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
+        decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
+    elif kernel == 'linear':
+        decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
+        decay_coefficient, _ = decay_matrix.min(0)
+    else:
+        raise NotImplementedError(
+            f'{kernel} kernel is not supported in matrix nms!')
+    # update the score.
+    scores = scores * decay_coefficient
+
+    if filter_thr > 0:
+        keep = scores >= filter_thr
+        keep_inds = keep_inds[keep]
+        if not keep.any():
+            return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
+                0, *masks.shape[-2:]), labels.new_zeros(0)
+        masks = masks[keep]
+        scores = scores[keep]
+        labels = labels[keep]
+
+    # sort and keep top max_num
+    scores, sort_inds = torch.sort(scores, descending=True)
+    keep_inds = keep_inds[sort_inds]
+    if max_num > 0 and len(sort_inds) > max_num:
+        sort_inds = sort_inds[:max_num]
+        keep_inds = keep_inds[:max_num]
+        scores = scores[:max_num]
+    masks = masks[sort_inds]
+    labels = labels[sort_inds]
+
+    return scores, labels, masks, keep_inds
diff --git a/mmdet/core/utils/__init__.py b/mmdet/core/utils/__init__.py
index 5b2f69fef..bbd909ff6 100644
--- a/mmdet/core/utils/__init__.py
+++ b/mmdet/core/utils/__init__.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads,
                          reduce_mean)
-from .misc import flip_tensor, mask2ndarray, multi_apply, unmap
+from .misc import (center_of_mass, flip_tensor, generate_coordinate,
+                   mask2ndarray, multi_apply, unmap)

 __all__ = [
     'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
-    'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict'
+    'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict',
+    'center_of_mass', 'generate_coordinate'
 ]
diff --git a/mmdet/core/utils/misc.py b/mmdet/core/utils/misc.py
index a2c2ecef7..36bb6883d 100644
--- a/mmdet/core/utils/misc.py
+++ b/mmdet/core/utils/misc.py
@@ -83,3 +83,46 @@ def flip_tensor(src_tensor, flip_direction):
     else:
         out_tensor = torch.flip(src_tensor, [2, 3])
     return out_tensor
+
+
+def center_of_mass(mask, eps=1e-6):
+    """Calculate the centroid coordinates of the mask.
+
+    Args:
+        mask (Tensor): The mask to be calculated, shape (h, w).
+        eps (float): Avoid dividing by zero. Default: 1e-6.
+
+    Returns:
+        tuple[Tensor]: The coordinates of the center point of the mask.
+
+            - center_h (Tensor): the center point of the height.
+            - center_w (Tensor): the center point of the width.
+    """
+    h, w = mask.shape
+    grid_h = torch.arange(h, device=mask.device)[:, None]
+    grid_w = torch.arange(w, device=mask.device)
+    normalizer = mask.sum().float().clamp(min=eps)
+    center_h = (mask * grid_h).sum() / normalizer
+    center_w = (mask * grid_w).sum() / normalizer
+    return center_h, center_w
+
+
+def generate_coordinate(featmap_sizes, device='cuda'):
+    """Generate the coordinate feature.
+
+    Args:
+        featmap_sizes (tuple): The size of the feature map to generate
+            coordinates for, in (N, C, H, W) order.
+        device (str): The device where the feature will be put on.
+    Returns:
+        coord_feat (Tensor): The coordinate feature, of shape (N, 2, H, W).
+    """
+
+    x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
+    y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
+    y, x = torch.meshgrid(y_range, x_range)
+    y = y.expand([featmap_sizes[0], 1, -1, -1])
+    x = x.expand([featmap_sizes[0], 1, -1, -1])
+    coord_feat = torch.cat([x, y], 1)
+
+    return coord_feat
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
index 328a654aa..c3440fed5 100644
--- a/mmdet/models/dense_heads/__init__.py
+++ b/mmdet/models/dense_heads/__init__.py
@@ -28,6 +28,7 @@ from .retina_head import RetinaHead
 from .retina_sepbn_head import RetinaSepBNHead
 from .rpn_head import RPNHead
 from .sabl_retina_head import SABLRetinaHead
+from .solo_head import DecoupledSOLOHead, DecoupledSOLOLightHead, SOLOHead
 from .ssd_head import SSDHead
 from .vfnet_head import VFNetHead
 from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
@@ -45,5 +46,6 @@ __all__ = [
     'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
     'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
     'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
-    'CenterNetHead', 'YOLOXHead'
+    'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
+    'DecoupledSOLOLightHead'
 ]
diff --git a/mmdet/models/dense_heads/solo_head.py b/mmdet/models/dense_heads/solo_head.py
new file mode 100644
index 000000000..148f819fa
--- /dev/null
+++ b/mmdet/models/dense_heads/solo_head.py
@@ -0,0 +1,1177 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from mmdet.core import InstanceData, mask_matrix_nms, multi_apply
+from mmdet.core.utils import center_of_mass, generate_coordinate
+from mmdet.models.builder import HEADS, build_loss
+from .base_mask_head import BaseMaskHead
+
+
+@HEADS.register_module()
+class SOLOHead(BaseMaskHead):
+    """SOLO mask head used in `SOLO: Segmenting Objects by Locations.
+
+    <https://arxiv.org/abs/1912.04488>`_
+
+    Args:
+        num_classes (int): Number of categories excluding the background
+            category.
+        in_channels (int): Number of channels in the input feature map.
+        feat_channels (int): Number of hidden channels. Used in child classes.
+            Default: 256.
+        stacked_convs (int): Number of stacking convs of the head.
+            Default: 4.
+        strides (tuple): Downsample factor of each feature map.
+        scale_ranges (tuple[tuple[int, int]]): Area range of multiple
+            level masks, in the format [(min1, max1), (min2, max2), ...].
+            A range of (16, 64) means that instances with scale between
+            16 and 64 are assigned to this level.
+        pos_scale (float): Constant scale factor to control the center region.
+        num_grids (list[int]): Divide the image into uniform grids; each
+            feature map uses a different grid size. The number of output
+            channels of the mask branch is grid ** 2.
+            Default: [40, 36, 24, 16, 12].
+        cls_down_index (int): The index of the downsample operation in the
+            classification branch. Default: 0.
+        loss_mask (dict): Config of mask loss.
+        loss_cls (dict): Config of classification loss.
+        norm_cfg (dict): Dictionary to construct and config norm layer.
+            Default: norm_cfg=dict(type='GN', num_groups=32,
+            requires_grad=True).
+        train_cfg (dict): Training config of head.
+        test_cfg (dict): Testing config of head.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
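+
+    Example:
+        A minimal construction sketch: the loss configs mirror
+        configs/solo/solo_r50_fpn_1x_coco.py and the feature sizes are
+        illustrative.
+
+        >>> import torch
+        >>> self = SOLOHead(
+        ...     num_classes=80,
+        ...     in_channels=256,
+        ...     loss_mask=dict(type='DiceLoss', use_sigmoid=True,
+        ...                    loss_weight=3.0),
+        ...     loss_cls=dict(type='FocalLoss', use_sigmoid=True,
+        ...                   loss_weight=1.0))
+        >>> feats = [torch.rand(1, 256, 64 // 2**i, 64 // 2**i)
+        ...          for i in range(self.num_levels)]
+        >>> mask_preds, cls_preds = self.forward(feats)
+        >>> [p.shape[1] for p in mask_preds] == [g**2 for g in self.num_grids]
+        True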
+ """ + + def __init__( + self, + num_classes, + in_channels, + feat_channels=256, + stacked_convs=4, + strides=(4, 8, 16, 32, 64), + scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)), + pos_scale=0.2, + num_grids=[40, 36, 24, 16, 12], + cls_down_index=0, + loss_mask=None, + loss_cls=None, + norm_cfg=dict(type='GN', num_groups=32, requires_grad=True), + train_cfg=None, + test_cfg=None, + init_cfg=[ + dict(type='Normal', layer='Conv2d', std=0.01), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_cls')) + ], + ): + super(SOLOHead, self).__init__(init_cfg) + self.num_classes = num_classes + self.cls_out_channels = self.num_classes + self.in_channels = in_channels + self.feat_channels = feat_channels + self.stacked_convs = stacked_convs + self.strides = strides + self.num_grids = num_grids + # number of FPN feats + self.num_levels = len(strides) + assert self.num_levels == len(scale_ranges) == len(num_grids) + self.scale_ranges = scale_ranges + self.pos_scale = pos_scale + + self.cls_down_index = cls_down_index + self.loss_cls = build_loss(loss_cls) + self.loss_mask = build_loss(loss_mask) + self.norm_cfg = norm_cfg + self.init_cfg = init_cfg + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers() + + def _init_layers(self): + self.mask_convs = nn.ModuleList() + self.cls_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels + 2 if i == 0 else self.feat_channels + self.mask_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + self.conv_mask_list = nn.ModuleList() + for num_grid in self.num_grids: + self.conv_mask_list.append( + nn.Conv2d(self.feat_channels, num_grid**2, 1)) + + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + + def resize_feats(self, feats): + """Downsample the first feat and upsample last feat in feats.""" + out = [] + for i in range(len(feats)): + if i == 0: + out.append( + F.interpolate(feats[0], scale_factor=0.5, mode='bilinear')) + elif i == len(feats) - 1: + out.append( + F.interpolate( + feats[i], + size=feats[i - 1].shape[-2:], + mode='bilinear')) + else: + out.append(feats[i]) + return out + + def forward(self, feats): + assert len(feats) == self.num_levels + feats = self.resize_feats(feats) + mlvl_mask_preds = [] + mlvl_cls_preds = [] + for i in range(self.num_levels): + x = feats[i] + mask_feat = x + cls_feat = x + # generate and concat the coordinate + coord_feat = generate_coordinate(mask_feat.size(), + mask_feat.device) + mask_feat = torch.cat([mask_feat, coord_feat], 1) + + for mask_layer in (self.mask_convs): + mask_feat = mask_layer(mask_feat) + + mask_feat = F.interpolate( + mask_feat, scale_factor=2, mode='bilinear') + mask_pred = self.conv_mask_list[i](mask_feat) + + # cls branch + for j, cls_layer in enumerate(self.cls_convs): + if j == self.cls_down_index: + num_grid = self.num_grids[i] + cls_feat = F.interpolate( + cls_feat, size=num_grid, mode='bilinear') + cls_feat = cls_layer(cls_feat) + + cls_pred = self.conv_cls(cls_feat) + + if not self.training: + feat_wh = feats[0].size()[-2:] + upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) + mask_pred = F.interpolate( + 
mask_pred.sigmoid(), size=upsampled_size, mode='bilinear')
+                cls_pred = cls_pred.sigmoid()
+                # get local maximum
+                local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1)
+                keep_mask = local_max[:, :, :-1, :-1] == cls_pred
+                cls_pred = cls_pred * keep_mask
+
+            mlvl_mask_preds.append(mask_pred)
+            mlvl_cls_preds.append(cls_pred)
+        return mlvl_mask_preds, mlvl_cls_preds
+
+    def loss(self,
+             mlvl_mask_preds,
+             mlvl_cls_preds,
+             gt_labels,
+             gt_masks,
+             img_metas,
+             gt_bboxes=None,
+             **kwargs):
+        """Calculate the loss over the whole batch.
+
+        Args:
+            mlvl_mask_preds (list[Tensor]): Multi-level mask predictions.
+                Each element in the list has shape
+                (batch_size, num_grids**2, h, w).
+            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
+                in the list has shape
+                (batch_size, num_classes, num_grids, num_grids).
+            gt_labels (list[Tensor]): Labels of multiple images.
+            gt_masks (list[Tensor]): Ground truth masks of multiple images.
+                Each has shape (num_instances, h, w).
+            img_metas (list[dict]): Meta information of multiple images.
+            gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
+                images. Default: None.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        num_levels = self.num_levels
+        num_imgs = len(gt_labels)
+
+        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds]
+
+        # Each `BoolTensor` in `pos_masks` represents whether the
+        # corresponding grid point is positive.
+        pos_mask_targets, labels, pos_masks = multi_apply(
+            self._get_targets_single,
+            gt_bboxes,
+            gt_labels,
+            gt_masks,
+            featmap_sizes=featmap_sizes)
+
+        # change the outer list from per-image to per-level
+        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
+        mlvl_pos_mask_preds = [[] for _ in range(num_levels)]
+        mlvl_pos_masks = [[] for _ in range(num_levels)]
+        mlvl_labels = [[] for _ in range(num_levels)]
+        for img_id in range(num_imgs):
+            assert num_levels == len(pos_mask_targets[img_id])
+            for lvl in range(num_levels):
+                mlvl_pos_mask_targets[lvl].append(
+                    pos_mask_targets[img_id][lvl])
+                mlvl_pos_mask_preds[lvl].append(
+                    mlvl_mask_preds[lvl][img_id, pos_masks[img_id][lvl], ...])
+                mlvl_pos_masks[lvl].append(pos_masks[img_id][lvl].flatten())
+                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
+
+        # concatenate over images
+        temp_mlvl_cls_preds = []
+        for lvl in range(num_levels):
+            mlvl_pos_mask_targets[lvl] = torch.cat(
+                mlvl_pos_mask_targets[lvl], dim=0)
+            mlvl_pos_mask_preds[lvl] = torch.cat(
+                mlvl_pos_mask_preds[lvl], dim=0)
+            mlvl_pos_masks[lvl] = torch.cat(mlvl_pos_masks[lvl], dim=0)
+            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
+            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
+                0, 2, 3, 1).reshape(-1, self.cls_out_channels))
+
+        num_pos = sum(item.sum() for item in mlvl_pos_masks)
+        # dice loss
+        loss_mask = []
+        for pred, target in zip(mlvl_pos_mask_preds, mlvl_pos_mask_targets):
+            if pred.size()[0] == 0:
+                loss_mask.append(pred.sum().unsqueeze(0))
+                continue
+            loss_mask.append(
+                self.loss_mask(pred, target, reduction_override='none'))
+        if num_pos > 0:
+            loss_mask = torch.cat(loss_mask).sum() / num_pos
+        else:
+            loss_mask = torch.cat(loss_mask).mean()
+
+        flatten_labels = torch.cat(mlvl_labels)
+        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
+        loss_cls = self.loss_cls(
+            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
+        return dict(loss_mask=loss_mask, loss_cls=loss_cls)
+
+    def _get_targets_single(self,
+                            gt_bboxes,
+                            gt_labels,
+                            gt_masks,
+                            featmap_sizes=None):
+        """Compute
targets for predictions of single image. + + Args: + gt_bboxes (Tensor): Ground truth bbox of each instance, + shape (num_gts, 4). + gt_labels (Tensor): Ground truth label of each instance, + shape (num_gts,). + gt_masks (Tensor): Ground truth mask of each instance, + shape (num_gts, h, w). + featmap_sizes (list[:obj:`torch.size`]): Size of each + feature map from feature pyramid, each element + means (feat_h, feat_w). Default: None. + + Returns: + Tuple: Usually returns a tuple containing targets for predictions. + + - mlvl_pos_mask_targets (list[Tensor]): Each element represent + the binary mask targets for positive points in this + level, has shape (num_pos, out_h, out_w). + - mlvl_labels (list[Tensor]): Each element is + classification labels for all + points in this level, has shape + (num_grid, num_grid). + - mlvl_pos_masks (list[Tensor]): Each element is + a `BoolTensor` to represent whether the + corresponding point in single level + is positive, has shape (num_grid **2). + """ + device = gt_labels.device + gt_areas = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) * + (gt_bboxes[:, 3] - gt_bboxes[:, 1])) + + mlvl_pos_mask_targets = [] + mlvl_labels = [] + mlvl_pos_masks = [] + for (lower_bound, upper_bound), stride, featmap_size, num_grid \ + in zip(self.scale_ranges, self.strides, + featmap_sizes, self.num_grids): + + mask_target = torch.zeros( + [num_grid**2, featmap_size[0], featmap_size[1]], + dtype=torch.uint8, + device=device) + # FG cat_id: [0, num_classes -1], BG cat_id: num_classes + labels = torch.zeros([num_grid, num_grid], + dtype=torch.int64, + device=device) + self.num_classes + pos_mask = torch.zeros([num_grid**2], + dtype=torch.bool, + device=device) + + gt_inds = ((gt_areas >= lower_bound) & + (gt_areas <= upper_bound)).nonzero().flatten() + if len(gt_inds) == 0: + mlvl_pos_mask_targets.append( + mask_target.new_zeros(0, featmap_size[0], featmap_size[1])) + mlvl_labels.append(labels) + mlvl_pos_masks.append(pos_mask) + continue + hit_gt_bboxes = gt_bboxes[gt_inds] + hit_gt_labels = gt_labels[gt_inds] + hit_gt_masks = gt_masks[gt_inds, ...] + + pos_w_ranges = 0.5 * (hit_gt_bboxes[:, 2] - + hit_gt_bboxes[:, 0]) * self.pos_scale + pos_h_ranges = 0.5 * (hit_gt_bboxes[:, 3] - + hit_gt_bboxes[:, 1]) * self.pos_scale + + # Make sure hit_gt_masks has a value + valid_mask_flags = hit_gt_masks.sum(dim=-1).sum(dim=-1) > 0 + output_stride = stride / 2 + + for gt_mask, gt_label, pos_h_range, pos_w_range, \ + valid_mask_flag in \ + zip(hit_gt_masks, hit_gt_labels, pos_h_ranges, + pos_w_ranges, valid_mask_flags): + if not valid_mask_flag: + continue + upsampled_size = (featmap_sizes[0][0] * 4, + featmap_sizes[0][1] * 4) + center_h, center_w = center_of_mass(gt_mask) + + coord_w = int( + (center_w / upsampled_size[1]) // (1. / num_grid)) + coord_h = int( + (center_h / upsampled_size[0]) // (1. / num_grid)) + + # left, top, right, down + top_box = max( + 0, + int(((center_h - pos_h_range) / upsampled_size[0]) // + (1. / num_grid))) + down_box = min( + num_grid - 1, + int(((center_h + pos_h_range) / upsampled_size[0]) // + (1. / num_grid))) + left_box = max( + 0, + int(((center_w - pos_w_range) / upsampled_size[1]) // + (1. / num_grid))) + right_box = min( + num_grid - 1, + int(((center_w + pos_w_range) / upsampled_size[1]) // + (1. 
/ num_grid)))
+
+                top = max(top_box, coord_h - 1)
+                down = min(down_box, coord_h + 1)
+                left = max(coord_w - 1, left_box)
+                right = min(right_box, coord_w + 1)
+
+                labels[top:(down + 1), left:(right + 1)] = gt_label
+                # instance mask target
+                gt_mask = np.uint8(gt_mask.cpu().numpy())
+                # Follow the original implementation and rescale with
+                # mmcv (cv2); F.interpolate produces slightly different
+                # results.
+                gt_mask = mmcv.imrescale(gt_mask, scale=1. / output_stride)
+                gt_mask = torch.from_numpy(gt_mask).to(device=device)
+
+                for i in range(top, down + 1):
+                    for j in range(left, right + 1):
+                        index = int(i * num_grid + j)
+                        mask_target[index, :gt_mask.shape[0], :gt_mask.
+                                    shape[1]] = gt_mask
+                        pos_mask[index] = True
+            mlvl_pos_mask_targets.append(mask_target[pos_mask])
+            mlvl_labels.append(labels)
+            mlvl_pos_masks.append(pos_mask)
+        return mlvl_pos_mask_targets, mlvl_labels, mlvl_pos_masks
+
+    def get_results(self, mlvl_mask_preds, mlvl_cls_scores, img_metas,
+                    **kwargs):
+        """Get multi-image mask results.
+
+        Args:
+            mlvl_mask_preds (list[Tensor]): Multi-level mask predictions.
+                Each element in the list has shape
+                (batch_size, num_grids**2, h, w).
+            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
+                in the list has shape
+                (batch_size, num_classes, num_grids, num_grids).
+            img_metas (list[dict]): Meta information of all images.
+
+        Returns:
+            list[:obj:`InstanceData`]: Processed results of multiple
+            images. Each :obj:`InstanceData` usually contains the
+            following keys.
+
+                - scores (Tensor): Classification scores, has shape
+                  (num_instances,).
+                - labels (Tensor): Has shape (num_instances,).
+                - masks (Tensor): Processed mask results, has
+                  shape (num_instances, h, w).
+        """
+        mlvl_cls_scores = [
+            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
+        ]
+        assert len(mlvl_mask_preds) == len(mlvl_cls_scores)
+        num_levels = len(mlvl_cls_scores)
+
+        results_list = []
+        for img_id in range(len(img_metas)):
+            cls_pred_list = [
+                mlvl_cls_scores[lvl][img_id].view(-1, self.cls_out_channels)
+                for lvl in range(num_levels)
+            ]
+            mask_pred_list = [
+                mlvl_mask_preds[lvl][img_id] for lvl in range(num_levels)
+            ]
+
+            cls_pred_list = torch.cat(cls_pred_list, dim=0)
+            mask_pred_list = torch.cat(mask_pred_list, dim=0)
+
+            results = self._get_results_single(
+                cls_pred_list, mask_pred_list, img_meta=img_metas[img_id])
+            results_list.append(results)
+
+        return results_list
+
+    def _get_results_single(self, cls_scores, mask_preds, img_meta, cfg=None):
+        """Get processed mask related results of single image.
+
+        Args:
+            cls_scores (Tensor): Classification score of all points
+                in single image, has shape (num_points, num_classes).
+            mask_preds (Tensor): Mask prediction of all points in
+                single image, has shape (num_points, feat_h, feat_w).
+            img_meta (dict): Meta information of corresponding image.
+            cfg (dict, optional): Config used in test phase.
+                Default: None.
+
+        Returns:
+            :obj:`InstanceData`: Processed results of single image.
+            It usually contains the following keys.
+
+                - scores (Tensor): Classification scores, has shape
+                  (num_instances,).
+                - labels (Tensor): Has shape (num_instances,).
+                - masks (Tensor): Processed mask results, has
+                  shape (num_instances, h, w).
+ """ + + def empty_results(results, cls_scores): + """Generate a empty results.""" + results.scores = cls_scores.new_ones(0) + results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2]) + results.labels = cls_scores.new_ones(0) + return results + + cfg = self.test_cfg if cfg is None else cfg + assert len(cls_scores) == len(mask_preds) + results = InstanceData(img_meta) + + featmap_size = mask_preds.size()[-2:] + + img_shape = results.img_shape + ori_shape = results.ori_shape + + h, w, _ = img_shape + upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4) + + score_mask = (cls_scores > cfg.score_thr) + cls_scores = cls_scores[score_mask] + if len(cls_scores) == 0: + return empty_results(results, cls_scores) + + inds = score_mask.nonzero() + cls_labels = inds[:, 1] + + # Filter the mask mask with an area is smaller than + # stride of corresponding feature level + lvl_interval = cls_labels.new_tensor(self.num_grids).pow(2).cumsum(0) + strides = cls_scores.new_ones(lvl_interval[-1]) + strides[:lvl_interval[0]] *= self.strides[0] + for lvl in range(1, self.num_levels): + strides[lvl_interval[lvl - + 1]:lvl_interval[lvl]] *= self.strides[lvl] + strides = strides[inds[:, 0]] + mask_preds = mask_preds[inds[:, 0]] + + masks = mask_preds > cfg.mask_thr + sum_masks = masks.sum((1, 2)).float() + keep = sum_masks > strides + if keep.sum() == 0: + return empty_results(results, cls_scores) + masks = masks[keep] + mask_preds = mask_preds[keep] + sum_masks = sum_masks[keep] + cls_scores = cls_scores[keep] + cls_labels = cls_labels[keep] + + # maskness. + mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks + cls_scores *= mask_scores + + scores, labels, _, keep_inds = mask_matrix_nms( + masks, + cls_labels, + cls_scores, + mask_area=sum_masks, + nms_pre=cfg.nms_pre, + max_num=cfg.max_per_img, + kernel=cfg.kernel, + sigma=cfg.sigma, + filter_thr=cfg.filter_thr) + mask_preds = mask_preds[keep_inds] + mask_preds = F.interpolate( + mask_preds.unsqueeze(0), size=upsampled_size, + mode='bilinear')[:, :, :h, :w] + mask_preds = F.interpolate( + mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0) + masks = mask_preds > cfg.mask_thr + + results.masks = masks + results.labels = labels + results.scores = scores + + return results + + +@HEADS.register_module() +class DecoupledSOLOHead(SOLOHead): + """Decoupled SOLO mask head used in `SOLO: Segmenting Objects by Locations. + + `_ + + Args: + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ """ + + def __init__(self, + *args, + init_cfg=[ + dict(type='Normal', layer='Conv2d', std=0.01), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list_x')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list_y')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_cls')) + ], + **kwargs): + super(DecoupledSOLOHead, self).__init__( + *args, init_cfg=init_cfg, **kwargs) + + def _init_layers(self): + self.mask_convs_x = nn.ModuleList() + self.mask_convs_y = nn.ModuleList() + self.cls_convs = nn.ModuleList() + + for i in range(self.stacked_convs): + chn = self.in_channels + 1 if i == 0 else self.feat_channels + self.mask_convs_x.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + self.mask_convs_y.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + norm_cfg=self.norm_cfg)) + + self.conv_mask_list_x = nn.ModuleList() + self.conv_mask_list_y = nn.ModuleList() + for num_grid in self.num_grids: + self.conv_mask_list_x.append( + nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) + self.conv_mask_list_y.append( + nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + + def forward(self, feats): + assert len(feats) == self.num_levels + feats = self.resize_feats(feats) + mask_preds_x = [] + mask_preds_y = [] + cls_preds = [] + for i in range(self.num_levels): + x = feats[i] + mask_feat = x + cls_feat = x + # generate and concat the coordinate + coord_feat = generate_coordinate(mask_feat.size(), + mask_feat.device) + mask_feat_x = torch.cat([mask_feat, coord_feat[:, 0:1, ...]], 1) + mask_feat_y = torch.cat([mask_feat, coord_feat[:, 1:2, ...]], 1) + + for mask_layer_x, mask_layer_y in \ + zip(self.mask_convs_x, self.mask_convs_y): + mask_feat_x = mask_layer_x(mask_feat_x) + mask_feat_y = mask_layer_y(mask_feat_y) + + mask_feat_x = F.interpolate( + mask_feat_x, scale_factor=2, mode='bilinear') + mask_feat_y = F.interpolate( + mask_feat_y, scale_factor=2, mode='bilinear') + + mask_pred_x = self.conv_mask_list_x[i](mask_feat_x) + mask_pred_y = self.conv_mask_list_y[i](mask_feat_y) + + # cls branch + for j, cls_layer in enumerate(self.cls_convs): + if j == self.cls_down_index: + num_grid = self.num_grids[i] + cls_feat = F.interpolate( + cls_feat, size=num_grid, mode='bilinear') + cls_feat = cls_layer(cls_feat) + + cls_pred = self.conv_cls(cls_feat) + + if not self.training: + feat_wh = feats[0].size()[-2:] + upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) + mask_pred_x = F.interpolate( + mask_pred_x.sigmoid(), + size=upsampled_size, + mode='bilinear') + mask_pred_y = F.interpolate( + mask_pred_y.sigmoid(), + size=upsampled_size, + mode='bilinear') + cls_pred = cls_pred.sigmoid() + # get local maximum + local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1) + keep_mask = local_max[:, :, :-1, :-1] == cls_pred + cls_pred = cls_pred * keep_mask + + mask_preds_x.append(mask_pred_x) + mask_preds_y.append(mask_pred_y) + cls_preds.append(cls_pred) + return mask_preds_x, mask_preds_y, cls_preds + + def loss(self, + mlvl_mask_preds_x, + mlvl_mask_preds_y, + mlvl_cls_preds, + gt_labels, + gt_masks, + img_metas, + gt_bboxes=None, + 
**kwargs):
+        """Calculate the loss over the whole batch.
+
+        Args:
+            mlvl_mask_preds_x (list[Tensor]): Multi-level mask predictions
+                from the x branch. Each element in the list has shape
+                (batch_size, num_grids, h, w).
+            mlvl_mask_preds_y (list[Tensor]): Multi-level mask predictions
+                from the y branch. Each element in the list has shape
+                (batch_size, num_grids, h, w).
+            mlvl_cls_preds (list[Tensor]): Multi-level scores. Each element
+                in the list has shape
+                (batch_size, num_classes, num_grids, num_grids).
+            gt_labels (list[Tensor]): Labels of multiple images.
+            gt_masks (list[Tensor]): Ground truth masks of multiple images.
+                Each has shape (num_instances, h, w).
+            img_metas (list[dict]): Meta information of multiple images.
+            gt_bboxes (list[Tensor]): Ground truth bboxes of multiple
+                images. Default: None.
+
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        num_levels = self.num_levels
+        num_imgs = len(gt_labels)
+        featmap_sizes = [featmap.size()[-2:] for featmap in mlvl_mask_preds_x]
+
+        pos_mask_targets, labels, \
+            xy_pos_indexes = \
+            multi_apply(self._get_targets_single,
+                        gt_bboxes,
+                        gt_labels,
+                        gt_masks,
+                        featmap_sizes=featmap_sizes)
+
+        # change the outer list from per-image to per-level
+        mlvl_pos_mask_targets = [[] for _ in range(num_levels)]
+        mlvl_pos_mask_preds_x = [[] for _ in range(num_levels)]
+        mlvl_pos_mask_preds_y = [[] for _ in range(num_levels)]
+        mlvl_labels = [[] for _ in range(num_levels)]
+        for img_id in range(num_imgs):
+
+            for lvl in range(num_levels):
+                mlvl_pos_mask_targets[lvl].append(
+                    pos_mask_targets[img_id][lvl])
+                mlvl_pos_mask_preds_x[lvl].append(
+                    mlvl_mask_preds_x[lvl][img_id,
+                                           xy_pos_indexes[img_id][lvl][:, 1]])
+                mlvl_pos_mask_preds_y[lvl].append(
+                    mlvl_mask_preds_y[lvl][img_id,
+                                           xy_pos_indexes[img_id][lvl][:, 0]])
+                mlvl_labels[lvl].append(labels[img_id][lvl].flatten())
+
+        # concatenate over images
+        temp_mlvl_cls_preds = []
+        for lvl in range(num_levels):
+            mlvl_pos_mask_targets[lvl] = torch.cat(
+                mlvl_pos_mask_targets[lvl], dim=0)
+            mlvl_pos_mask_preds_x[lvl] = torch.cat(
+                mlvl_pos_mask_preds_x[lvl], dim=0)
+            mlvl_pos_mask_preds_y[lvl] = torch.cat(
+                mlvl_pos_mask_preds_y[lvl], dim=0)
+            mlvl_labels[lvl] = torch.cat(mlvl_labels[lvl], dim=0)
+            temp_mlvl_cls_preds.append(mlvl_cls_preds[lvl].permute(
+                0, 2, 3, 1).reshape(-1, self.cls_out_channels))
+
+        num_pos = 0.
+        # dice loss
+        loss_mask = []
+        for pred_x, pred_y, target in \
+                zip(mlvl_pos_mask_preds_x,
+                    mlvl_pos_mask_preds_y, mlvl_pos_mask_targets):
+            num_masks = pred_x.size(0)
+            if num_masks == 0:
+                # make sure the gradient can still flow
+                loss_mask.append((pred_x.sum() + pred_y.sum()).unsqueeze(0))
+                continue
+            num_pos += num_masks
+            pred_mask = pred_y.sigmoid() * pred_x.sigmoid()
+            loss_mask.append(
+                self.loss_mask(pred_mask, target, reduction_override='none'))
+        if num_pos > 0:
+            loss_mask = torch.cat(loss_mask).sum() / num_pos
+        else:
+            loss_mask = torch.cat(loss_mask).mean()
+
+        # classification branch
+        flatten_labels = torch.cat(mlvl_labels)
+        flatten_cls_preds = torch.cat(temp_mlvl_cls_preds)
+
+        loss_cls = self.loss_cls(
+            flatten_cls_preds, flatten_labels, avg_factor=num_pos + 1)
+        return dict(loss_mask=loss_mask, loss_cls=loss_cls)
+
+    def _get_targets_single(self,
+                            gt_bboxes,
+                            gt_labels,
+                            gt_masks,
+                            featmap_sizes=None):
+        """Compute targets for predictions of single image.
+
+        Args:
+            gt_bboxes (Tensor): Ground truth bbox of each instance,
+                shape (num_gts, 4).
+            gt_labels (Tensor): Ground truth label of each instance,
+                shape (num_gts,).
+            gt_masks (Tensor): Ground truth mask of each instance,
+                shape (num_gts, h, w).
+            featmap_sizes (list[:obj:`torch.size`]): Size of each
+                feature map from feature pyramid, each element
+                means (feat_h, feat_w). Default: None.
+
+        Returns:
+            Tuple: Usually returns a tuple containing targets for predictions.
+
+                - mlvl_pos_mask_targets (list[Tensor]): Each element
+                  represents the binary mask targets for positive points
+                  in this level, has shape (num_pos, out_h, out_w).
+                - mlvl_labels (list[Tensor]): Each element is
+                  classification labels for all
+                  points in this level, has shape
+                  (num_grid, num_grid).
+                - mlvl_xy_pos_indexes (list[Tensor]): Each element
+                  in the list contains the index of positive samples in
+                  corresponding level, has shape (num_pos, 2), the last
+                  dimension of 2 denotes (index_y, index_x).
+        """
+        mlvl_pos_mask_targets, mlvl_labels, \
+            mlvl_pos_masks = \
+            super()._get_targets_single(gt_bboxes, gt_labels, gt_masks,
+                                        featmap_sizes=featmap_sizes)
+
+        mlvl_xy_pos_indexes = [(item - self.num_classes).nonzero()
+                               for item in mlvl_labels]
+
+        return mlvl_pos_mask_targets, mlvl_labels, mlvl_xy_pos_indexes
+
+    def get_results(self,
+                    mlvl_mask_preds_x,
+                    mlvl_mask_preds_y,
+                    mlvl_cls_scores,
+                    img_metas,
+                    rescale=None,
+                    **kwargs):
+        """Get multi-image mask results.
+
+        Args:
+            mlvl_mask_preds_x (list[Tensor]): Multi-level mask predictions
+                from the x branch. Each element in the list has shape
+                (batch_size, num_grids, h, w).
+            mlvl_mask_preds_y (list[Tensor]): Multi-level mask predictions
+                from the y branch. Each element in the list has shape
+                (batch_size, num_grids, h, w).
+            mlvl_cls_scores (list[Tensor]): Multi-level scores. Each element
+                in the list has shape
+                (batch_size, num_classes, num_grids, num_grids).
+            img_metas (list[dict]): Meta information of all images.
+
+        Returns:
+            list[:obj:`InstanceData`]: Processed results of multiple
+            images. Each :obj:`InstanceData` usually contains the
+            following keys.
+
+                - scores (Tensor): Classification scores, has shape
+                  (num_instances,).
+                - labels (Tensor): Has shape (num_instances,).
+                - masks (Tensor): Processed mask results, has
+                  shape (num_instances, h, w).
+        """
+        mlvl_cls_scores = [
+            item.permute(0, 2, 3, 1) for item in mlvl_cls_scores
+        ]
+        assert len(mlvl_mask_preds_x) == len(mlvl_cls_scores)
+        num_levels = len(mlvl_cls_scores)
+
+        results_list = []
+        for img_id in range(len(img_metas)):
+            cls_pred_list = [
+                mlvl_cls_scores[i][img_id].view(
+                    -1, self.cls_out_channels).detach()
+                for i in range(num_levels)
+            ]
+            mask_pred_list_x = [
+                mlvl_mask_preds_x[i][img_id] for i in range(num_levels)
+            ]
+            mask_pred_list_y = [
+                mlvl_mask_preds_y[i][img_id] for i in range(num_levels)
+            ]
+
+            cls_pred_list = torch.cat(cls_pred_list, dim=0)
+            mask_pred_list_x = torch.cat(mask_pred_list_x, dim=0)
+            mask_pred_list_y = torch.cat(mask_pred_list_y, dim=0)
+
+            results = self._get_results_single(
+                cls_pred_list,
+                mask_pred_list_x,
+                mask_pred_list_y,
+                img_meta=img_metas[img_id],
+                cfg=self.test_cfg)
+            results_list.append(results)
+        return results_list
+
+    def _get_results_single(self, cls_scores, mask_preds_x, mask_preds_y,
+                            img_meta, cfg):
+        """Get processed mask related results of single image.
+
+        Args:
+            cls_scores (Tensor): Classification score of all points
+                in single image, has shape (num_points, num_classes).
+            mask_preds_x (Tensor): Mask prediction of x branch of
+                all points in single image, has shape
+                (sum_num_grids, feat_h, feat_w).
+            mask_preds_y (Tensor): Mask prediction of y branch of
+                all points in single image, has shape
+                (sum_num_grids, feat_h, feat_w).
+            img_meta (dict): Meta information of corresponding image.
+            cfg (dict): Config used in test phase.
+
+        Returns:
+            :obj:`InstanceData`: Processed results of single image.
+            It usually contains the following keys.
+
+                - scores (Tensor): Classification scores, has shape
+                  (num_instances,).
+                - labels (Tensor): Has shape (num_instances,).
+                - masks (Tensor): Processed mask results, has
+                  shape (num_instances, h, w).
+        """
+
+        def empty_results(results, cls_scores):
+            """Generate empty results."""
+            results.scores = cls_scores.new_ones(0)
+            results.masks = cls_scores.new_zeros(0, *results.ori_shape[:2])
+            results.labels = cls_scores.new_ones(0)
+            return results
+
+        cfg = self.test_cfg if cfg is None else cfg
+
+        results = InstanceData(img_meta)
+        img_shape = results.img_shape
+        ori_shape = results.ori_shape
+        h, w, _ = img_shape
+        featmap_size = mask_preds_x.size()[-2:]
+        upsampled_size = (featmap_size[0] * 4, featmap_size[1] * 4)
+
+        score_mask = (cls_scores > cfg.score_thr)
+        cls_scores = cls_scores[score_mask]
+        inds = score_mask.nonzero()
+        lvl_interval = inds.new_tensor(self.num_grids).pow(2).cumsum(0)
+        num_all_points = lvl_interval[-1]
+        lvl_start_index = inds.new_ones(num_all_points)
+        num_grids = inds.new_ones(num_all_points)
+        seg_size = inds.new_tensor(self.num_grids).cumsum(0)
+        mask_lvl_start_index = inds.new_ones(num_all_points)
+        strides = inds.new_ones(num_all_points)
+
+        lvl_start_index[:lvl_interval[0]] *= 0
+        mask_lvl_start_index[:lvl_interval[0]] *= 0
+        num_grids[:lvl_interval[0]] *= self.num_grids[0]
+        strides[:lvl_interval[0]] *= self.strides[0]
+
+        for lvl in range(1, self.num_levels):
+            lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
+                lvl_interval[lvl - 1]
+            mask_lvl_start_index[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
+                seg_size[lvl - 1]
+            num_grids[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
+                self.num_grids[lvl]
+            strides[lvl_interval[lvl - 1]:lvl_interval[lvl]] *= \
+                self.strides[lvl]
+
+        lvl_start_index = lvl_start_index[inds[:, 0]]
+        mask_lvl_start_index = mask_lvl_start_index[inds[:, 0]]
+        num_grids = num_grids[inds[:, 0]]
+        strides = strides[inds[:, 0]]
+
+        y_lvl_offset = (inds[:, 0] - lvl_start_index) // num_grids
+        x_lvl_offset = (inds[:, 0] - lvl_start_index) % num_grids
+        y_inds = mask_lvl_start_index + y_lvl_offset
+        x_inds = mask_lvl_start_index + x_lvl_offset
+
+        cls_labels = inds[:, 1]
+        mask_preds = mask_preds_x[x_inds, ...] * mask_preds_y[y_inds, ...]
+
+        masks = mask_preds > cfg.mask_thr
+        sum_masks = masks.sum((1, 2)).float()
+        keep = sum_masks > strides
+        if keep.sum() == 0:
+            return empty_results(results, cls_scores)
+
+        masks = masks[keep]
+        mask_preds = mask_preds[keep]
+        sum_masks = sum_masks[keep]
+        cls_scores = cls_scores[keep]
+        cls_labels = cls_labels[keep]
+
+        # maskness.
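+        # Maskness reweights each classification score by the mean mask
+        # probability inside its binarized mask, favouring confident and
+        # well-localized predictions before matrix NMS.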
+ mask_scores = (mask_preds * masks).sum((1, 2)) / sum_masks + cls_scores *= mask_scores + + scores, labels, _, keep_inds = mask_matrix_nms( + masks, + cls_labels, + cls_scores, + mask_area=sum_masks, + nms_pre=cfg.nms_pre, + max_num=cfg.max_per_img, + kernel=cfg.kernel, + sigma=cfg.sigma, + filter_thr=cfg.filter_thr) + mask_preds = mask_preds[keep_inds] + mask_preds = F.interpolate( + mask_preds.unsqueeze(0), size=upsampled_size, + mode='bilinear')[:, :, :h, :w] + mask_preds = F.interpolate( + mask_preds, size=ori_shape[:2], mode='bilinear').squeeze(0) + masks = mask_preds > cfg.mask_thr + + results.masks = masks + results.labels = labels + results.scores = scores + + return results + + +@HEADS.register_module() +class DecoupledSOLOLightHead(DecoupledSOLOHead): + """Decoupled Light SOLO mask head used in `SOLO: Segmenting Objects by + Locations `_ + + Args: + with_dcn (bool): Whether use dcn in mask_convs and cls_convs, + default: False. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + *args, + dcn_cfg=None, + init_cfg=[ + dict(type='Normal', layer='Conv2d', std=0.01), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list_x')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_mask_list_y')), + dict( + type='Normal', + std=0.01, + bias_prob=0.01, + override=dict(name='conv_cls')) + ], + **kwargs): + assert dcn_cfg is None or isinstance(dcn_cfg, dict) + self.dcn_cfg = dcn_cfg + super(DecoupledSOLOLightHead, self).__init__( + *args, init_cfg=init_cfg, **kwargs) + + def _init_layers(self): + self.mask_convs = nn.ModuleList() + self.cls_convs = nn.ModuleList() + + for i in range(self.stacked_convs): + if self.dcn_cfg is not None\ + and i == self.stacked_convs - 1: + conv_cfg = self.dcn_cfg + else: + conv_cfg = None + + chn = self.in_channels + 2 if i == 0 else self.feat_channels + self.mask_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg)) + + chn = self.in_channels if i == 0 else self.feat_channels + self.cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=self.norm_cfg)) + + self.conv_mask_list_x = nn.ModuleList() + self.conv_mask_list_y = nn.ModuleList() + for num_grid in self.num_grids: + self.conv_mask_list_x.append( + nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) + self.conv_mask_list_y.append( + nn.Conv2d(self.feat_channels, num_grid, 3, padding=1)) + self.conv_cls = nn.Conv2d( + self.feat_channels, self.cls_out_channels, 3, padding=1) + + def forward(self, feats): + assert len(feats) == self.num_levels + feats = self.resize_feats(feats) + mask_preds_x = [] + mask_preds_y = [] + cls_preds = [] + for i in range(self.num_levels): + x = feats[i] + mask_feat = x + cls_feat = x + # generate and concat the coordinate + coord_feat = generate_coordinate(mask_feat.size(), + mask_feat.device) + mask_feat = torch.cat([mask_feat, coord_feat], 1) + + for mask_layer in self.mask_convs: + mask_feat = mask_layer(mask_feat) + + mask_feat = F.interpolate( + mask_feat, scale_factor=2, mode='bilinear') + + mask_pred_x = self.conv_mask_list_x[i](mask_feat) + mask_pred_y = self.conv_mask_list_y[i](mask_feat) + + # cls branch + for j, cls_layer in enumerate(self.cls_convs): + if j == self.cls_down_index: + num_grid = self.num_grids[i] + cls_feat = F.interpolate( + cls_feat, size=num_grid, mode='bilinear') + cls_feat = 
cls_layer(cls_feat) + + cls_pred = self.conv_cls(cls_feat) + + if not self.training: + feat_wh = feats[0].size()[-2:] + upsampled_size = (feat_wh[0] * 2, feat_wh[1] * 2) + mask_pred_x = F.interpolate( + mask_pred_x.sigmoid(), + size=upsampled_size, + mode='bilinear') + mask_pred_y = F.interpolate( + mask_pred_y.sigmoid(), + size=upsampled_size, + mode='bilinear') + cls_pred = cls_pred.sigmoid() + # get local maximum + local_max = F.max_pool2d(cls_pred, 2, stride=1, padding=1) + keep_mask = local_max[:, :, :-1, :-1] == cls_pred + cls_pred = cls_pred * keep_mask + + mask_preds_x.append(mask_pred_x) + mask_preds_y.append(mask_pred_y) + cls_preds.append(cls_pred) + return mask_preds_x, mask_preds_y, cls_preds diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py index cad0a770e..08fad5468 100644 --- a/mmdet/models/detectors/__init__.py +++ b/mmdet/models/detectors/__init__.py @@ -28,6 +28,7 @@ from .retinanet import RetinaNet from .rpn import RPN from .scnet import SCNet from .single_stage import SingleStageDetector +from .solo import SOLO from .sparse_rcnn import SparseRCNN from .trident_faster_rcnn import TridentFasterRCNN from .two_stage import TwoStageDetector @@ -43,7 +44,7 @@ __all__ = [ 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector', 'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT', - 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', + 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO', 'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX', 'TwoStagePanopticSegmentor', 'PanopticFPN' ] diff --git a/mmdet/models/detectors/solo.py b/mmdet/models/detectors/solo.py new file mode 100644 index 000000000..9f45d314e --- /dev/null +++ b/mmdet/models/detectors/solo.py @@ -0,0 +1,29 @@ +from ..builder import DETECTORS +from .single_stage_instance_seg import SingleStageInstanceSegmentor + + +@DETECTORS.register_module() +class SOLO(SingleStageInstanceSegmentor): + """`SOLO: Segmenting Objects by Locations + `_ + + """ + + def __init__(self, + backbone, + neck=None, + bbox_head=None, + mask_head=None, + train_cfg=None, + test_cfg=None, + init_cfg=None, + pretrained=None): + super().__init__( + backbone=backbone, + neck=neck, + bbox_head=bbox_head, + mask_head=mask_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + init_cfg=init_cfg, + pretrained=pretrained) diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py index 645db0b4c..068a54d65 100644 --- a/mmdet/models/losses/__init__.py +++ b/mmdet/models/losses/__init__.py @@ -4,6 +4,7 @@ from .ae_loss import AssociativeEmbeddingLoss from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy) +from .dice_loss import DiceLoss from .focal_loss import FocalLoss, sigmoid_focal_loss from .gaussian_focal_loss import GaussianFocalLoss from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss @@ -27,5 +28,5 @@ __all__ = [ 'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss', 'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss', 'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss', - 'VarifocalLoss', 'KnowledgeDistillationKLDivLoss', 'SeesawLoss' + 'VarifocalLoss', 'KnowledgeDistillationKLDivLoss', 'SeesawLoss', 'DiceLoss' ] diff --git a/mmdet/models/losses/dice_loss.py b/mmdet/models/losses/dice_loss.py new 
file mode 100644 index 000000000..0551d143b --- /dev/null +++ b/mmdet/models/losses/dice_loss.py @@ -0,0 +1,123 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+def dice_loss(pred,
+              target,
+              weight=None,
+              eps=1e-3,
+              reduction='mean',
+              avg_factor=None):
+    """Calculate dice loss, which is proposed in
+    `V-Net: Fully Convolutional Neural Networks for Volumetric
+    Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
+
+    Args:
+        pred (torch.Tensor): The prediction, has a shape (n, *).
+        target (torch.Tensor): The learning label of the prediction,
+            shape (n, *), same shape as pred.
+        weight (torch.Tensor, optional): The weight of loss for each
+            prediction, has a shape (n,). Defaults to None.
+        eps (float): Avoid dividing by zero. Default: 1e-3.
+        reduction (str, optional): The method used to reduce the loss into
+            a scalar. Defaults to 'mean'.
+            Options are "none", "mean" and "sum".
+        avg_factor (int, optional): Average factor that is used to average
+            the loss. Defaults to None.
+
+    Returns:
+        torch.Tensor: The calculated loss.
+    """
+
+    pred_flat = pred.reshape(pred.size()[0], -1)
+    target_flat = target.reshape(target.size()[0], -1).float()
+
+    # V-Net style dice: 2 * sum(p * t) / (sum(p ^ 2) + sum(t ^ 2)),
+    # computed per sample over the flattened masks.
+    a = torch.sum(pred_flat * target_flat, 1)
+    b = torch.sum(pred_flat * pred_flat, 1) + eps
+    c = torch.sum(target_flat * target_flat, 1) + eps
+    d = (2 * a) / (b + c)
+    loss = 1 - d
+    if weight is not None:
+        assert weight.ndim == loss.ndim
+        assert len(weight) == len(pred)
+    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+    return loss
+
+
+@LOSSES.register_module()
+class DiceLoss(nn.Module):
+
+    def __init__(self,
+                 use_sigmoid=True,
+                 activate=True,
+                 reduction='mean',
+                 loss_weight=1.0,
+                 eps=1e-3):
+        """Dice Loss, which is proposed in
+        `V-Net: Fully Convolutional Neural Networks for Volumetric
+        Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
+
+        Args:
+            use_sigmoid (bool, optional): Whether the prediction is
+                activated by sigmoid (softmax is not supported yet).
+                Defaults to True.
+            activate (bool): Whether to activate the predictions inside
+                the loss function; if False, the predictions are expected
+                to be already activated. Defaults to True.
+            reduction (str, optional): The method used
+                to reduce the loss. Options are "none",
+                "mean" and "sum". Defaults to 'mean'.
+            loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+            eps (float): Avoid dividing by zero. Defaults to 1e-3.
+        """
+
+        super(DiceLoss, self).__init__()
+        self.use_sigmoid = use_sigmoid
+        self.reduction = reduction
+        self.loss_weight = loss_weight
+        self.eps = eps
+        self.activate = activate
+
+    def forward(self,
+                pred,
+                target,
+                weight=None,
+                reduction_override=None,
+                avg_factor=None):
+        """Forward function.
+
+        Args:
+            pred (torch.Tensor): The prediction, has a shape (n, *).
+            target (torch.Tensor): The label of the prediction,
+                shape (n, *), same shape as pred.
+            weight (torch.Tensor, optional): The weight of loss for each
+                prediction, has a shape (n,). Defaults to None.
+            avg_factor (int, optional): Average factor that is used to average
+                the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Options are "none", "mean" and "sum".
+ + Returns: + torch.Tensor: The calculated loss + """ + + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + if self.activate: + if self.use_sigmoid: + pred = pred.sigmoid() + else: + raise NotImplementedError + + loss = self.loss_weight * dice_loss( + pred, + target, + weight, + eps=self.eps, + reduction=reduction, + avg_factor=avg_factor) + + return loss diff --git a/model-index.yml b/model-index.yml index aa66f7ce5..aebc55baa 100644 --- a/model-index.yml +++ b/model-index.yml @@ -38,6 +38,7 @@ Import: - configs/nas_fpn/metafile.yml - configs/paa/metafile.yml - configs/pafpn/metafile.yml + - configs/pvt/metafile.yml - configs/pisa/metafile.yml - configs/point_rend/metafile.yml - configs/regnet/metafile.yml @@ -49,6 +50,7 @@ Import: - configs/scnet/metafile.yml - configs/scratch/metafile.yml - configs/sparse_rcnn/metafile.yml + - configs/solo/metafile.yml - configs/ssd/metafile.yml - configs/tridentnet/metafile.yml - configs/vfnet/metafile.yml diff --git a/tests/test_models/test_dense_heads/test_solo_head.py b/tests/test_models/test_dense_heads/test_solo_head.py new file mode 100644 index 000000000..16cb4f7cc --- /dev/null +++ b/tests/test_models/test_dense_heads/test_solo_head.py @@ -0,0 +1,284 @@ +import pytest +import torch + +from mmdet.models.dense_heads import (DecoupledSOLOHead, + DecoupledSOLOLightHead, SOLOHead) + + +def test_solo_head_loss(): + """Tests solo head loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + self = SOLOHead( + num_classes=4, + in_channels=1, + num_grids=[40, 36, 24, 16, 12], + loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0)) + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32, 64] + ] + mask_preds, cls_preds = self.forward(feat) + # Test that empty ground truth encourages the network to + # predict background. + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + gt_masks = [torch.empty((0, 550, 550))] + gt_bboxes_ignore = None + empty_gt_losses = self.loss( + mask_preds, + cls_preds, + gt_labels, + gt_masks, + img_metas, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + empty_mask_loss = empty_gt_losses['loss_mask'] + empty_cls_loss = empty_gt_losses['loss_cls'] + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_mask_loss.item() == 0, ( + 'there should be no mask loss when there are no true masks') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs. + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + gt_masks = [(torch.rand((1, 256, 256)) > 0.5).float()] + one_gt_losses = self.loss( + mask_preds, + cls_preds, + gt_labels, + gt_masks, + img_metas, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore) + onegt_mask_loss = one_gt_losses['loss_mask'] + onegt_cls_loss = one_gt_losses['loss_cls'] + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_mask_loss.item() > 0, 'mask loss should be non-zero' + + # When the length of num_grids, scale_ranges, and num_levels are not equal. 
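+    # num_grids here has only four entries while the default strides and
+    # scale_ranges define five levels, so the constructor's length
+    # consistency assertion is expected to fire.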
+ with pytest.raises(AssertionError): + SOLOHead( + num_classes=4, + in_channels=1, + num_grids=[36, 24, 16, 12], + loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0)) + + # When input feature length is not equal to num_levels. + with pytest.raises(AssertionError): + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32] + ] + self.forward(feat) + + +def test_desolo_head_loss(): + """Tests solo head loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + self = DecoupledSOLOHead( + num_classes=4, + in_channels=1, + num_grids=[40, 36, 24, 16, 12], + loss_mask=dict( + type='DiceLoss', use_sigmoid=True, activate=False, + loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0)) + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32, 64] + ] + mask_preds_x, mask_preds_y, cls_preds = self.forward(feat) + # Test that empty ground truth encourages the network to + # predict background. + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + gt_masks = [torch.empty((0, 550, 550))] + gt_bboxes_ignore = None + empty_gt_losses = self.loss( + mask_preds_x, + mask_preds_y, + cls_preds, + gt_labels, + gt_masks, + img_metas, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + empty_mask_loss = empty_gt_losses['loss_mask'] + empty_cls_loss = empty_gt_losses['loss_cls'] + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_mask_loss.item() == 0, ( + 'there should be no mask loss when there are no true masks') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs. + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + gt_masks = [(torch.rand((1, 256, 256)) > 0.5).float()] + one_gt_losses = self.loss( + mask_preds_x, + mask_preds_y, + cls_preds, + gt_labels, + gt_masks, + img_metas, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore) + onegt_mask_loss = one_gt_losses['loss_mask'] + onegt_cls_loss = one_gt_losses['loss_cls'] + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_mask_loss.item() > 0, 'mask loss should be non-zero' + + # When the length of num_grids, scale_ranges, and num_levels are not equal. + with pytest.raises(AssertionError): + DecoupledSOLOHead( + num_classes=4, + in_channels=1, + num_grids=[36, 24, 16, 12], + loss_mask=dict( + type='DiceLoss', + use_sigmoid=True, + activate=False, + loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0)) + + # When input feature length is not equal to num_levels. 
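+    # forward() asserts len(feats) == self.num_levels (five here), so a
+    # four-level input is expected to raise an AssertionError.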
+ with pytest.raises(AssertionError): + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32] + ] + self.forward(feat) + + +def test_desolo_light_head_loss(): + """Tests solo head loss when truth is empty and non-empty.""" + s = 256 + img_metas = [{ + 'img_shape': (s, s, 3), + 'scale_factor': 1, + 'pad_shape': (s, s, 3) + }] + self = DecoupledSOLOLightHead( + num_classes=4, + in_channels=1, + num_grids=[40, 36, 24, 16, 12], + loss_mask=dict( + type='DiceLoss', use_sigmoid=True, activate=False, + loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0)) + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32, 64] + ] + mask_preds_x, mask_preds_y, cls_preds = self.forward(feat) + # Test that empty ground truth encourages the network to + # predict background. + gt_bboxes = [torch.empty((0, 4))] + gt_labels = [torch.LongTensor([])] + gt_masks = [torch.empty((0, 550, 550))] + gt_bboxes_ignore = None + empty_gt_losses = self.loss( + mask_preds_x, + mask_preds_y, + cls_preds, + gt_labels, + gt_masks, + img_metas, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore) + # When there is no truth, the cls loss should be nonzero but there should + # be no box loss. + empty_mask_loss = empty_gt_losses['loss_mask'] + empty_cls_loss = empty_gt_losses['loss_cls'] + assert empty_cls_loss.item() > 0, 'cls loss should be non-zero' + assert empty_mask_loss.item() == 0, ( + 'there should be no mask loss when there are no true masks') + + # When truth is non-empty then both cls and box loss should be nonzero for + # random inputs. + gt_bboxes = [ + torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]), + ] + gt_labels = [torch.LongTensor([2])] + gt_masks = [(torch.rand((1, 256, 256)) > 0.5).float()] + one_gt_losses = self.loss( + mask_preds_x, + mask_preds_y, + cls_preds, + gt_labels, + gt_masks, + img_metas, + gt_bboxes, + gt_bboxes_ignore=gt_bboxes_ignore) + onegt_mask_loss = one_gt_losses['loss_mask'] + onegt_cls_loss = one_gt_losses['loss_cls'] + assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero' + assert onegt_mask_loss.item() > 0, 'mask loss should be non-zero' + + # When the length of num_grids, scale_ranges, and num_levels are not equal. + with pytest.raises(AssertionError): + DecoupledSOLOLightHead( + num_classes=4, + in_channels=1, + num_grids=[36, 24, 16, 12], + loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0)) + + # When input feature length is not equal to num_levels. 
+ with pytest.raises(AssertionError): + feat = [ + torch.rand(1, 1, s // feat_size, s // feat_size) + for feat_size in [4, 8, 16, 32] + ] + self.forward(feat) diff --git a/tests/test_models/test_loss.py b/tests/test_models/test_loss.py index e86a40fae..a8ebd109a 100644 --- a/tests/test_models/test_loss.py +++ b/tests/test_models/test_loss.py @@ -9,6 +9,7 @@ from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss, KnowledgeDistillationKLDivLoss, L1Loss, MSELoss, QualityFocalLoss, SeesawLoss, SmoothL1Loss, VarifocalLoss) +from mmdet.models.losses.dice_loss import DiceLoss from mmdet.models.losses.ghm_loss import GHMC, GHMR from mmdet.models.losses.iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss) @@ -29,7 +30,7 @@ def test_iou_type_loss_zeros_weight(loss_class): BalancedL1Loss, BoundedIoULoss, CIoULoss, CrossEntropyLoss, DIoULoss, FocalLoss, DistributionFocalLoss, MSELoss, SeesawLoss, GaussianFocalLoss, GIoULoss, IoULoss, L1Loss, QualityFocalLoss, VarifocalLoss, GHMR, GHMC, - SmoothL1Loss, KnowledgeDistillationKLDivLoss + SmoothL1Loss, KnowledgeDistillationKLDivLoss, DiceLoss ]) def test_loss_with_reduction_override(loss_class): pred = torch.rand((10, 4)) @@ -163,3 +164,53 @@ def test_loss_with_ignore_index(use_sigmoid): assert torch.allclose(loss, loss_with_ignore) assert torch.allclose(loss, loss_with_forward_ignore) + + +def test_dice_loss(): + loss_class = DiceLoss + pred = torch.rand((10, 4, 4)) + target = torch.rand((10, 4, 4)) + weight = torch.rand((10)) + + # Test loss forward + loss = loss_class()(pred, target) + assert isinstance(loss, torch.Tensor) + + # Test loss forward with weight + loss = loss_class()(pred, target, weight) + assert isinstance(loss, torch.Tensor) + + # Test loss forward with reduction_override + loss = loss_class()(pred, target, reduction_override='mean') + assert isinstance(loss, torch.Tensor) + + # Test loss forward with avg_factor + loss = loss_class()(pred, target, avg_factor=10) + assert isinstance(loss, torch.Tensor) + + with pytest.raises(ValueError): + # loss can evaluate with avg_factor only if + # reduction is None, 'none' or 'mean'. 
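+        # reduction_override='sum' combined with avg_factor triggers the
+        # ValueError raised inside weight_reduce_loss.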
+ reduction_override = 'sum' + loss_class()( + pred, target, avg_factor=10, reduction_override=reduction_override) + + # Test loss forward with avg_factor and reduction + for reduction_override in [None, 'none', 'mean']: + loss_class()( + pred, target, avg_factor=10, reduction_override=reduction_override) + assert isinstance(loss, torch.Tensor) + + # Test loss forward with has_acted=False and use_sigmoid=False + with pytest.raises(NotImplementedError): + loss_class(use_sigmoid=False, activate=True)(pred, target) + + # Test loss forward with weight.ndim != loss.ndim + with pytest.raises(AssertionError): + weight = torch.rand((2, 8)) + loss_class()(pred, target, weight) + + # Test loss forward with len(weight) != len(pred) + with pytest.raises(AssertionError): + weight = torch.rand((8)) + loss_class()(pred, target, weight) diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 05c87bbe7..f01f72f2b 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -5,7 +5,7 @@ import torch from mmdet.core.bbox import distance2bbox from mmdet.core.mask.structures import BitmapMasks, PolygonMasks -from mmdet.core.utils import mask2ndarray +from mmdet.core.utils import center_of_mass, mask2ndarray def dummy_raw_polygon_masks(size): @@ -91,3 +91,20 @@ def test_distance2bbox(): deltas = torch.zeros((2, 0, 4)) out = distance2bbox(rois, deltas, max_shape=(120, 100)) assert rois.shape == out.shape + + +@pytest.mark.parametrize('mask', [ + torch.ones((28, 28)), + torch.zeros((28, 28)), + torch.rand(28, 28) > 0.5, + torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) +]) +def test_center_of_mass(mask): + center_h, center_w = center_of_mass(mask) + if mask.shape[0] == 4: + assert center_h == 1.5 + assert center_w == 1.5 + assert isinstance(center_h, torch.Tensor) \ + and isinstance(center_w, torch.Tensor) + assert 0 <= center_h <= 28 \ + and 0 <= center_w <= 28 diff --git a/tests/test_utils/test_nms.py b/tests/test_utils/test_nms.py new file mode 100644 index 000000000..5fa92dc5a --- /dev/null +++ b/tests/test_utils/test_nms.py @@ -0,0 +1,75 @@ +import pytest +import torch + +from mmdet.core.post_processing import mask_matrix_nms + + +def _create_mask(N, h, w): + masks = torch.rand((N, h, w)) > 0.5 + labels = torch.rand(N) + scores = torch.rand(N) + return masks, labels, scores + + +def test_nms_input_errors(): + with pytest.raises(AssertionError): + mask_matrix_nms( + torch.rand((10, 28, 28)), torch.rand(11), torch.rand(11)) + with pytest.raises(AssertionError): + masks = torch.rand((10, 28, 28)) + mask_matrix_nms( + masks, + torch.rand(11), + torch.rand(11), + mask_area=masks.sum((1, 2)).float()[:8]) + with pytest.raises(NotImplementedError): + mask_matrix_nms( + torch.rand((10, 28, 28)), + torch.rand(10), + torch.rand(10), + kernel='None') + # test an empty results + masks, labels, scores = _create_mask(0, 28, 28) + score, label, mask, keep_ind = \ + mask_matrix_nms(masks, labels, scores) + assert len(score) == len(label) == \ + len(mask) == len(keep_ind) == 0 + + # do not use update_thr, nms_pre and max_num + masks, labels, scores = _create_mask(1000, 28, 28) + score, label, mask, keep_ind = \ + mask_matrix_nms(masks, labels, scores) + assert len(score) == len(label) == \ + len(mask) == len(keep_ind) == 1000 + # only use nms_pre + score, label, mask, keep_ind = \ + mask_matrix_nms(masks, labels, scores, nms_pre=500) + assert len(score) == len(label) == \ + len(mask) == len(keep_ind) == 500 + # use max_num + score, label, mask, keep_ind = \ 
mask_matrix_nms(masks, labels, scores,
+                        nms_pre=500, max_num=100)
+    assert len(score) == len(label) == \
+           len(mask) == len(keep_ind) == 100
+
+    masks, labels, _ = _create_mask(1, 28, 28)
+    scores = torch.Tensor([1.0])
+    masks = masks.expand(1000, 28, 28)
+    labels = labels.expand(1000)
+    scores = scores.expand(1000)
+
+    # Assert that scores are decayed and that filter_thr takes effect:
+    # with identical masks and labels and all scores equal to 1,
+    # the first score stays 1 while the others decay and are filtered.
+    score, label, mask, keep_ind = \
+        mask_matrix_nms(masks,
+                        labels,
+                        scores,
+                        nms_pre=500,
+                        max_num=100,
+                        kernel='gaussian',
+                        sigma=2.0,
+                        filter_thr=0.5)
+    assert len(score) == 1
+    assert score[0] == 1
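For reviewers who want to poke at the new models end to end, a minimal inference sketch using mmdet's high-level API is below. It assumes a trained checkpoint is available; the checkpoint path and demo image are illustrative placeholders, while the config file is the one added by this PR.

```python
from mmdet.apis import inference_detector, init_detector

# The config ships with this PR; the checkpoint path and the demo image
# below are illustrative placeholders.
config_file = 'configs/solo/solo_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/solo_r50_fpn_1x_coco.pth'

model = init_detector(config_file, checkpoint_file, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')
# Draw the predicted instance masks on the image and save it.
model.show_result('demo/demo.jpg', result, out_file='solo_out.jpg')
```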