[Feature]SOLO: Segmenting Objects by Locations (#5832)

* add SOLO

* add decoupled SOLO

* update decoupled SOLO

* fix linting errors

* format config filename, config content, loss names, norm_cfg

* fix linting errors

* fix matrix_nms and configs

* Add unit tests for SOLO head

* add diceloss

* support mmdet-v2+

* add decopledhead

* clean Chinese comments

* update SOLO

* fix

* delet debug files

* update solo config

* fix bug

* [Fix]: fix some params cannot get grad

* [fix] make sure params can get grad

* init commit for resutls

* add results and instance results

* add docstr

* add more unitets

* add more unitets

* add more unitets

* add more unintest

* add unitet for instance results

* add example

* add meta_info_keys results_keys

* add modified from

* fix unitets

* fix typo

* add instance seg releated base

* forward train for solo

* fix simpletest

* add docstr

* convert to tensor at begin

* refactor yolact traing

* refactor yolact test

* fix test of yolact

* fix empty det of yolact

* fix return tuple

* add format_results

* add testfor formatr

* solo

* add unitest for format_results

* add unitest

* solo

* remove yolact relatede modification

* fix zero bbox

* fix score size

* fix desolo head

* update solo head

* fix error

* rename some attribute

* rename some attribute

* rename decouple

* add doc

* format loss

* reconer decople

* add doc

* fix test

* fix test

* fix doc

* remove points nms

* refactor the post process

* refactor post process of decaouple

* refactor base

* refactor get_target single

* refactor the training of decouple

* refactor test of decouple

* refactor dice loss

* refactor dice

* change to format a dict

* support detection results in test.py

* add base one-stage segmentor

* fix doc

* add onnx export

* add solo config

* add dice loss test unit

* add solo_head test unit

* add more detailed comments

* resolve commnets

* add test unit

* update docstrings and move center of mass to core.utils

* add center of mass test unit

* resolve comments

* resolve commets

* fix rle encode

* fix results

* fix results

* abstract dice loss

* update docstring

* add EPS

* add center of mass test unit

* add eps parameter

* add vis

* add nms test unit

* configs/

add configs

* add desolo light config file

* support desolo light head

* add desolo light config

* add matrix_nms test unit

* fix matrix_nms test unit

* update matrix doc string

* fix error

* fix logic error

* fix logic error

* add comment in test unit

* move has_acted to initialization

* update solo readme

* rename

* revert test

* fix import in example

* fix unitest

* add more uintest

* add more unites

* add more unitest

* rename meta to meta_info

* fix docstr

* fix foc

* fix doc

* add format_results

* fix format results

* fix some default value and function name

* fix desolo light head error

* fix doc and move isntancedata to a new file

* fix typo

* fix unitest in torch 13

* update matrix nms docstring

* fix hard code

* add vis

* add vis

* fix lint

* fix doc

* fix doc

* fix vis

* fix vis

* fix vis

* fix forwardummy doc

* fix doc

* fix comment

* fix doc

* fix order of argument

* add base one-stage segmentor

* fix config files

* fix doc

* fix doc

* support solo

* fix error

* support solo

* rename cls_score

* support solo

* update model zoo

* update docstring

* update docstring

Co-authored-by: WXinlong <wangxinlon@gmail.com>
Co-authored-by: zhangshilong <2392587229zsl@gmail.com>
pull/6194/head
BigDong 4 years ago committed by GitHub
parent b0cd4015c2
commit 2294badd86
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 42
      configs/solo/README.md
  2. 63
      configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py
  3. 28
      configs/solo/decoupled_solo_r50_fpn_1x_coco.py
  4. 25
      configs/solo/decoupled_solo_r50_fpn_3x_coco.py
  5. 115
      configs/solo/metafile.yml
  6. 53
      configs/solo/solo_r50_fpn_1x_coco.py
  7. 28
      configs/solo/solo_r50_fpn_3x_coco.py
  8. 8
      docs/model_zoo.md
  9. 3
      mmdet/core/post_processing/__init__.py
  10. 121
      mmdet/core/post_processing/matrix_nms.py
  11. 6
      mmdet/core/utils/__init__.py
  12. 43
      mmdet/core/utils/misc.py
  13. 4
      mmdet/models/dense_heads/__init__.py
  14. 1177
      mmdet/models/dense_heads/solo_head.py
  15. 3
      mmdet/models/detectors/__init__.py
  16. 29
      mmdet/models/detectors/solo.py
  17. 3
      mmdet/models/losses/__init__.py
  18. 123
      mmdet/models/losses/dice_loss.py
  19. 2
      model-index.yml
  20. 284
      tests/test_models/test_dense_heads/test_solo_head.py
  21. 53
      tests/test_models/test_loss.py
  22. 19
      tests/test_utils/test_misc.py
  23. 75
      tests/test_utils/test_nms.py

@ -0,0 +1,42 @@
# SOLO: Segmenting Objects by Locations
## Introduction
```
@inproceedings{wang2020solo,
title = {{SOLO}: Segmenting Objects by Locations},
author = {Wang, Xinlong and Kong, Tao and Shen, Chunhua and Jiang, Yuning and Li, Lei},
booktitle = {Proc. Eur. Conf. Computer Vision (ECCV)},
year = {2020}
}
```
## Results and Models
### SOLO
| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:------:|:--------:|
| R-50 | pytorch | N | 1x | 8.0 | 14.0 | 33.1 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055-2290a6b8.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055.log.json) |
| R-50 | pytorch | Y | 3x | 7.4 | 14.0 | 35.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353-11d224d7.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353.log.json) |
### Decoupled SOLO
| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:-------:|:--------:|
| R-50 | pytorch | N | 1x | 7.8 | 12.5 | 33.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348-6337c589.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348.log.json) |
| R-50 | pytorch | Y | 3x | 7.9 | 12.5 | 36.7 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504-7b3301ec.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504.log.json) |
- Decoupled SOLO has a decoupled head which is different from SOLO head.
Decoupled SOLO serves as an efficient and equivalent variant in accuracy
of SOLO. Please refer to the corresponding config files for details.
### Decoupled Light SOLO
| Backbone | Style | MS train | Lr schd | Mem (GB) | Inf time (fps) | mask AP | Download |
|:---------:|:-------:|:--------:|:-------:|:--------:|:--------------:|:------:|:--------:|
| R-50 | pytorch | Y | 3x | 2.2 | 31.2 | 32.9 | [model](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703-e70e226f.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703.log.json) |
- Decoupled Light SOLO using decoupled structure similar to Decoupled
SOLO head, with light-weight head and smaller input size, Please refer
to the corresponding config files for details.

@ -0,0 +1,63 @@
_base_ = './decoupled_solo_r50_fpn_3x_coco.py'
# model settings
model = dict(
mask_head=dict(
type='DecoupledSOLOLightHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 64), (32, 128), (64, 256), (128, 512), (256, 2048)),
pos_scale=0.2,
num_grids=[40, 36, 24, 16, 12],
cls_down_index=0,
loss_mask=dict(
type='DiceLoss', use_sigmoid=True, activate=False,
loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)))
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='Resize',
img_scale=[(852, 512), (852, 480), (852, 448), (852, 416), (852, 384),
(852, 352)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(852, 512),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

@ -0,0 +1,28 @@
_base_ = [
'./solo_r50_fpn_1x_coco.py',
]
# model settings
model = dict(
mask_head=dict(
type='DecoupledSOLOHead',
num_classes=80,
in_channels=256,
stacked_convs=7,
feat_channels=256,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
pos_scale=0.2,
num_grids=[40, 36, 24, 16, 12],
cls_down_index=0,
loss_mask=dict(
type='DiceLoss', use_sigmoid=True, activate=False,
loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)))
optimizer = dict(type='SGD', lr=0.01)

@ -0,0 +1,25 @@
_base_ = './solo_r50_fpn_3x_coco.py'
# model settings
model = dict(
mask_head=dict(
type='DecoupledSOLOHead',
num_classes=80,
in_channels=256,
stacked_convs=7,
feat_channels=256,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
pos_scale=0.2,
num_grids=[40, 36, 24, 16, 12],
cls_down_index=0,
loss_mask=dict(
type='DiceLoss', use_sigmoid=True, activate=False,
loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)))

@ -0,0 +1,115 @@
Collections:
- Name: SOLO
Metadata:
Training Data: COCO
Training Techniques:
- SGD with Momentum
- Weight Decay
Training Resources: 8x V100 GPUs
Architecture:
- FPN
- Convolution
- ResNet
Paper: https://arxiv.org/abs/1912.04488
README: configs/solo/README.md
Models:
- Name: decoupled_solo_r50_fpn_1x_coco
In Collection: SOLO
Config: configs/solo/decoupled_solo_r50_fpn_1x_coco.py
Metadata:
Training Memory (GB): 7.8
Epochs: 12
inference time (ms/im):
- value: 116.4
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (1333, 800)
Results:
- Task: Instance Segmentation
Dataset: COCO
Metrics:
mask AP: 33.9
Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_1x_coco/decoupled_solo_r50_fpn_1x_coco_20210820_233348-6337c589.pth
- Name: decoupled_solo_r50_fpn_3x_coco
In Collection: SOLO
Config: configs/solo/decoupled_solo_r50_fpn_3x_coco.py
Metadata:
Training Memory (GB): 7.9
Epochs: 36
inference time (ms/im):
- value: 117.2
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (1333, 800)
Results:
- Task: Instance Segmentation
Dataset: COCO
Metrics:
mask AP: 36.7
Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_r50_fpn_3x_coco/decoupled_solo_r50_fpn_3x_coco_20210821_042504-7b3301ec.pth
- Name: decoupled_solo_light_r50_fpn_3x_coco
In Collection: SOLO
Config: configs/solo/decoupled_solo_light_r50_fpn_3x_coco.py
Metadata:
Training Memory (GB): 2.2
Epochs: 36
inference time (ms/im):
- value: 35.0
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (852, 512)
Results:
- Task: Instance Segmentation
Dataset: COCO
Metrics:
mask AP: 32.9
Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/decoupled_solo_light_r50_fpn_3x_coco/decoupled_solo_light_r50_fpn_3x_coco_20210906_142703-e70e226f.pth
- Name: solo_r50_fpn_3x_coco
In Collection: SOLO
Config: configs/solo/solo_r50_fpn_3x_coco.py
Metadata:
Training Memory (GB): 7.4
Epochs: 36
inference time (ms/im):
- value: 94.2
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (1333, 800)
Results:
- Task: Instance Segmentation
Dataset: COCO
Metrics:
mask AP: 35.9
Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_3x_coco/solo_r50_fpn_3x_coco_20210901_012353-11d224d7.pth
- Name: solo_r50_fpn_1x_coco
In Collection: SOLO
Config: configs/solo/solo_r50_fpn_1x_coco.py
Metadata:
Training Memory (GB): 8.0
Epochs: 12
inference time (ms/im):
- value: 95.1
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (1333, 800)
Results:
- Task: Instance Segmentation
Dataset: COCO
Metrics:
mask AP: 33.1
Weights: https://download.openmmlab.com/mmdetection/v2.0/solo/solo_r50_fpn_1x_coco/solo_r50_fpn_1x_coco_20210821_035055-2290a6b8.pth

@ -0,0 +1,53 @@
_base_ = [
'../_base_/datasets/coco_instance.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
# model settings
model = dict(
type='SOLO',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'),
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=0,
num_outs=5),
mask_head=dict(
type='SOLOHead',
num_classes=80,
in_channels=256,
stacked_convs=7,
feat_channels=256,
strides=[8, 8, 16, 32, 32],
scale_ranges=((1, 96), (48, 192), (96, 384), (192, 768), (384, 2048)),
pos_scale=0.2,
num_grids=[40, 36, 24, 16, 12],
cls_down_index=0,
loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True)),
# model training and testing settings
test_cfg=dict(
nms_pre=500,
score_thr=0.1,
mask_thr=0.5,
filter_thr=0.05,
kernel='gaussian', # gaussian/linear
sigma=2.0,
max_per_img=100))
# optimizer
optimizer = dict(type='SGD', lr=0.01)

@ -0,0 +1,28 @@
_base_ = './solo_r50_fpn_1x_coco.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='Resize',
img_scale=[(1333, 800), (1333, 768), (1333, 736), (1333, 704),
(1333, 672), (1333, 640)],
multiscale_mode='value',
keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
data = dict(train=dict(pipeline=train_pipeline))
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=1.0 / 3,
step=[27, 33])
runner = dict(type='EpochBasedRunner', max_epochs=36)

@ -230,6 +230,14 @@ Please refer to [CenterNet](https://github.com/open-mmlab/mmdetection/blob/maste
Please refer to [YOLOX](https://github.com/open-mmlab/mmdetection/blob/master/configs/yolox) for details.
### PVT
Please refer to [PVT](https://github.com/open-mmlab/mmdetection/blob/master/configs/pvt) for details.
### SOLO
Please refer to [SOLO](https://github.com/open-mmlab/mmdetection/blob/master/configs/solo) for details.
### Other datasets
We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_nms import fast_nms, multiclass_nms
from .matrix_nms import mask_matrix_nms
from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
merge_aug_proposals, merge_aug_scores)
__all__ = [
'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
'merge_aug_scores', 'merge_aug_masks', 'fast_nms'
'merge_aug_scores', 'merge_aug_masks', 'mask_matrix_nms', 'fast_nms'
]

@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def mask_matrix_nms(masks,
labels,
scores,
filter_thr=-1,
nms_pre=-1,
max_num=-1,
kernel='gaussian',
sigma=2.0,
mask_area=None):
"""Matrix NMS for multi-class masks.
Args:
masks (Tensor): Has shape (num_instances, h, w)
labels (Tensor): Labels of corresponding masks,
has shape (num_instances,).
scores (Tensor): Mask scores of corresponding masks,
has shape (num_instances).
filter_thr (float): Score threshold to filter the masks
after matrix nms. Default: -1, which means do not
use filter_thr.
nms_pre (int): The max number of instances to do the matrix nms.
Default: -1, which means do not use nms_pre.
max_num (int, optional): If there are more than max_num masks after
matrix, only top max_num will be kept. Default: -1, which means
do not use max_num.
kernel (str): 'linear' or 'gaussian'.
sigma (float): std in gaussian method.
mask_area (Tensor): The sum of seg_masks.
Returns:
tuple(Tensor): Processed mask results.
- scores (Tensor): Updated scores, has shape (n,).
- labels (Tensor): Remained labels, has shape (n,).
- masks (Tensor): Remained masks, has shape (n, w, h).
- keep_inds (Tensor): The indexs number of
the remaining mask in the input mask, has shape (n,).
"""
assert len(labels) == len(masks) == len(scores)
if len(labels) == 0:
return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
0, *masks.shape[-2:]), labels.new_zeros(0)
if mask_area is None:
mask_area = masks.sum((1, 2)).float()
else:
assert len(masks) == len(mask_area)
# sort and keep top nms_pre
scores, sort_inds = torch.sort(scores, descending=True)
keep_inds = sort_inds
if nms_pre > 0 and len(sort_inds) > nms_pre:
sort_inds = sort_inds[:nms_pre]
keep_inds = keep_inds[:nms_pre]
scores = scores[:nms_pre]
masks = masks[sort_inds]
mask_area = mask_area[sort_inds]
labels = labels[sort_inds]
num_masks = len(labels)
flatten_masks = masks.reshape(num_masks, -1).float()
# inter.
inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0))
expanded_mask_area = mask_area.expand(num_masks, num_masks)
# Upper triangle iou matrix.
iou_matrix = (inter_matrix /
(expanded_mask_area + expanded_mask_area.transpose(1, 0) -
inter_matrix)).triu(diagonal=1)
# label_specific matrix.
expanded_labels = labels.expand(num_masks, num_masks)
# Upper triangle label matrix.
label_matrix = (expanded_labels == expanded_labels.transpose(
1, 0)).triu(diagonal=1)
# IoU compensation
compensate_iou, _ = (iou_matrix * label_matrix).max(0)
compensate_iou = compensate_iou.expand(num_masks,
num_masks).transpose(1, 0)
# IoU decay
decay_iou = iou_matrix * label_matrix
# Calculate the decay_coefficient
if kernel == 'gaussian':
decay_matrix = torch.exp(-1 * sigma * (decay_iou**2))
compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2))
decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0)
elif kernel == 'linear':
decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
decay_coefficient, _ = decay_matrix.min(0)
else:
raise NotImplementedError(
f'{kernel} kernel is not supported in matrix nms!')
# update the score.
scores = scores * decay_coefficient
if filter_thr > 0:
keep = scores >= filter_thr
keep_inds = keep_inds[keep]
if not keep.any():
return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros(
0, *masks.shape[-2:]), labels.new_zeros(0)
masks = masks[keep]
scores = scores[keep]
labels = labels[keep]
# sort and keep top max_num
scores, sort_inds = torch.sort(scores, descending=True)
keep_inds = keep_inds[sort_inds]
if max_num > 0 and len(sort_inds) > max_num:
sort_inds = sort_inds[:max_num]
keep_inds = keep_inds[:max_num]
scores = scores[:max_num]
masks = masks[sort_inds]
labels = labels[sort_inds]
return scores, labels, masks, keep_inds

@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads,
reduce_mean)
from .misc import flip_tensor, mask2ndarray, multi_apply, unmap
from .misc import (center_of_mass, flip_tensor, generate_coordinate,
mask2ndarray, multi_apply, unmap)
__all__ = [
'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict'
'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict',
'center_of_mass', 'generate_coordinate'
]

@ -83,3 +83,46 @@ def flip_tensor(src_tensor, flip_direction):
else:
out_tensor = torch.flip(src_tensor, [2, 3])
return out_tensor
def center_of_mass(mask, esp=1e-6):
"""Calculate the centroid coordinates of the mask.
Args:
mask (Tensor): The mask to be calculated, shape (h, w).
esp (float): Avoid dividing by zero. Default: 1e-6.
Returns:
tuple[Tensor]: the coordinates of the center point of the mask.
- center_h (Tensor): the center point of the height.
- center_w (Tensor): the center point of the width.
"""
h, w = mask.shape
grid_h = torch.arange(h, device=mask.device)[:, None]
grid_w = torch.arange(w, device=mask.device)
normalizer = mask.sum().float().clamp(min=esp)
center_h = (mask * grid_h).sum() / normalizer
center_w = (mask * grid_w).sum() / normalizer
return center_h, center_w
def generate_coordinate(featmap_sizes, device='cuda'):
"""Generate the coordinate.
Args:
featmap_sizes (tuple): The feature to be calculated,
of shape (N, C, W, H).
device (str): The device where the feature will be put on.
Returns:
coord_feat (Tensor): The coordinate feature, of shape (N, 2, W, H).
"""
x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
y, x = torch.meshgrid(y_range, x_range)
y = y.expand([featmap_sizes[0], 1, -1, -1])
x = x.expand([featmap_sizes[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
return coord_feat

@ -28,6 +28,7 @@ from .retina_head import RetinaHead
from .retina_sepbn_head import RetinaSepBNHead
from .rpn_head import RPNHead
from .sabl_retina_head import SABLRetinaHead
from .solo_head import DecoupledSOLOHead, DecoupledSOLOLightHead, SOLOHead
from .ssd_head import SSDHead
from .vfnet_head import VFNetHead
from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
@ -45,5 +46,6 @@ __all__ = [
'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
'CenterNetHead', 'YOLOXHead'
'SOLOHead', 'DecoupledSOLOHead', 'CenterNetHead', 'YOLOXHead',
'DecoupledSOLOLightHead'
]

File diff suppressed because it is too large Load Diff

@ -28,6 +28,7 @@ from .retinanet import RetinaNet
from .rpn import RPN
from .scnet import SCNet
from .single_stage import SingleStageDetector
from .solo import SOLO
from .sparse_rcnn import SparseRCNN
from .trident_faster_rcnn import TridentFasterRCNN
from .two_stage import TwoStageDetector
@ -43,7 +44,7 @@ __all__ = [
'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade', 'RetinaNet', 'FCOS',
'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector', 'FOVEA', 'FSAF',
'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA', 'YOLOV3', 'YOLACT',
'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet',
'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN', 'SCNet', 'SOLO',
'DeformableDETR', 'AutoAssign', 'YOLOF', 'CenterNet', 'YOLOX',
'TwoStagePanopticSegmentor', 'PanopticFPN'
]

@ -0,0 +1,29 @@
from ..builder import DETECTORS
from .single_stage_instance_seg import SingleStageInstanceSegmentor
@DETECTORS.register_module()
class SOLO(SingleStageInstanceSegmentor):
"""`SOLO: Segmenting Objects by Locations
<https://arxiv.org/abs/1912.04488>`_
"""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
mask_head=None,
train_cfg=None,
test_cfg=None,
init_cfg=None,
pretrained=None):
super().__init__(
backbone=backbone,
neck=neck,
bbox_head=bbox_head,
mask_head=mask_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
init_cfg=init_cfg,
pretrained=pretrained)

@ -4,6 +4,7 @@ from .ae_loss import AssociativeEmbeddingLoss
from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss, sigmoid_focal_loss
from .gaussian_focal_loss import GaussianFocalLoss
from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
@ -27,5 +28,5 @@ __all__ = [
'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
'VarifocalLoss', 'KnowledgeDistillationKLDivLoss', 'SeesawLoss'
'VarifocalLoss', 'KnowledgeDistillationKLDivLoss', 'SeesawLoss', 'DiceLoss'
]

@ -0,0 +1,123 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from ..builder import LOSSES
from .utils import weight_reduce_loss
def dice_loss(pred,
target,
weight=None,
eps=1e-3,
reduction='mean',
avg_factor=None):
"""Calculate dice loss, which is proposed in
`V-Net: Fully Convolutional Neural Networks for Volumetric
Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *)
target (torch.Tensor): The learning label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
eps (float): Avoid dividing by zero. Default: 1e-3.
reduction (str, optional): The method used to reduce the loss into
a scalar. Defaults to 'mean'.
Options are "none", "mean" and "sum".
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
"""
input = pred.reshape(pred.size()[0], -1)
target = target.reshape(target.size()[0], -1).float()
a = torch.sum(input * target, 1)
b = torch.sum(input * input, 1) + eps
c = torch.sum(target * target, 1) + eps
d = (2 * a) / (b + c)
loss = 1 - d
if weight is not None:
assert weight.ndim == loss.ndim
assert len(weight) == len(pred)
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
@LOSSES.register_module()
class DiceLoss(nn.Module):
def __init__(self,
use_sigmoid=True,
activate=True,
reduction='mean',
loss_weight=1.0,
eps=1e-3):
"""`Dice Loss, which is proposed in
`V-Net: Fully Convolutional Neural Networks for Volumetric
Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
Args:
use_sigmoid (bool, optional): Whether to the prediction is
used for sigmoid or softmax. Defaults to True.
activate (bool): Whether to activate the predictions inside,
this will disable the inside sigmoid operation.
Defaults to True.
reduction (str, optional): The method used
to reduce the loss. Options are "none",
"mean" and "sum". Defaults to 'mean'.
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
eps (float): Avoid dividing by zero. Defaults to 1e-3.
"""
super(DiceLoss, self).__init__()
self.use_sigmoid = use_sigmoid
self.reduction = reduction
self.loss_weight = loss_weight
self.eps = eps
self.activate = activate
def forward(self,
pred,
target,
weight=None,
reduction_override=None,
avg_factor=None):
"""Forward function.
Args:
pred (torch.Tensor): The prediction, has a shape (n, *).
target (torch.Tensor): The label of the prediction,
shape (n, *), same shape of pred.
weight (torch.Tensor, optional): The weight of loss for each
prediction, has a shape (n,). Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Options are "none", "mean" and "sum".
Returns:
torch.Tensor: The calculated loss
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
if self.activate:
if self.use_sigmoid:
pred = pred.sigmoid()
else:
raise NotImplementedError
loss = self.loss_weight * dice_loss(
pred,
target,
weight,
eps=self.eps,
reduction=reduction,
avg_factor=avg_factor)
return loss

@ -38,6 +38,7 @@ Import:
- configs/nas_fpn/metafile.yml
- configs/paa/metafile.yml
- configs/pafpn/metafile.yml
- configs/pvt/metafile.yml
- configs/pisa/metafile.yml
- configs/point_rend/metafile.yml
- configs/regnet/metafile.yml
@ -49,6 +50,7 @@ Import:
- configs/scnet/metafile.yml
- configs/scratch/metafile.yml
- configs/sparse_rcnn/metafile.yml
- configs/solo/metafile.yml
- configs/ssd/metafile.yml
- configs/tridentnet/metafile.yml
- configs/vfnet/metafile.yml

@ -0,0 +1,284 @@
import pytest
import torch
from mmdet.models.dense_heads import (DecoupledSOLOHead,
DecoupledSOLOLightHead, SOLOHead)
def test_solo_head_loss():
"""Tests solo head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'scale_factor': 1,
'pad_shape': (s, s, 3)
}]
self = SOLOHead(
num_classes=4,
in_channels=1,
num_grids=[40, 36, 24, 16, 12],
loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0))
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32, 64]
]
mask_preds, cls_preds = self.forward(feat)
# Test that empty ground truth encourages the network to
# predict background.
gt_bboxes = [torch.empty((0, 4))]
gt_labels = [torch.LongTensor([])]
gt_masks = [torch.empty((0, 550, 550))]
gt_bboxes_ignore = None
empty_gt_losses = self.loss(
mask_preds,
cls_preds,
gt_labels,
gt_masks,
img_metas,
gt_bboxes,
gt_bboxes_ignore=gt_bboxes_ignore)
# When there is no truth, the cls loss should be nonzero but there should
# be no box loss.
empty_mask_loss = empty_gt_losses['loss_mask']
empty_cls_loss = empty_gt_losses['loss_cls']
assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
assert empty_mask_loss.item() == 0, (
'there should be no mask loss when there are no true masks')
# When truth is non-empty then both cls and box loss should be nonzero for
# random inputs.
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
]
gt_labels = [torch.LongTensor([2])]
gt_masks = [(torch.rand((1, 256, 256)) > 0.5).float()]
one_gt_losses = self.loss(
mask_preds,
cls_preds,
gt_labels,
gt_masks,
img_metas,
gt_bboxes,
gt_bboxes_ignore=gt_bboxes_ignore)
onegt_mask_loss = one_gt_losses['loss_mask']
onegt_cls_loss = one_gt_losses['loss_cls']
assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
assert onegt_mask_loss.item() > 0, 'mask loss should be non-zero'
# When the length of num_grids, scale_ranges, and num_levels are not equal.
with pytest.raises(AssertionError):
SOLOHead(
num_classes=4,
in_channels=1,
num_grids=[36, 24, 16, 12],
loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0))
# When input feature length is not equal to num_levels.
with pytest.raises(AssertionError):
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32]
]
self.forward(feat)
def test_desolo_head_loss():
"""Tests solo head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'scale_factor': 1,
'pad_shape': (s, s, 3)
}]
self = DecoupledSOLOHead(
num_classes=4,
in_channels=1,
num_grids=[40, 36, 24, 16, 12],
loss_mask=dict(
type='DiceLoss', use_sigmoid=True, activate=False,
loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0))
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32, 64]
]
mask_preds_x, mask_preds_y, cls_preds = self.forward(feat)
# Test that empty ground truth encourages the network to
# predict background.
gt_bboxes = [torch.empty((0, 4))]
gt_labels = [torch.LongTensor([])]
gt_masks = [torch.empty((0, 550, 550))]
gt_bboxes_ignore = None
empty_gt_losses = self.loss(
mask_preds_x,
mask_preds_y,
cls_preds,
gt_labels,
gt_masks,
img_metas,
gt_bboxes,
gt_bboxes_ignore=gt_bboxes_ignore)
# When there is no truth, the cls loss should be nonzero but there should
# be no box loss.
empty_mask_loss = empty_gt_losses['loss_mask']
empty_cls_loss = empty_gt_losses['loss_cls']
assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
assert empty_mask_loss.item() == 0, (
'there should be no mask loss when there are no true masks')
# When truth is non-empty then both cls and box loss should be nonzero for
# random inputs.
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
]
gt_labels = [torch.LongTensor([2])]
gt_masks = [(torch.rand((1, 256, 256)) > 0.5).float()]
one_gt_losses = self.loss(
mask_preds_x,
mask_preds_y,
cls_preds,
gt_labels,
gt_masks,
img_metas,
gt_bboxes,
gt_bboxes_ignore=gt_bboxes_ignore)
onegt_mask_loss = one_gt_losses['loss_mask']
onegt_cls_loss = one_gt_losses['loss_cls']
assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
assert onegt_mask_loss.item() > 0, 'mask loss should be non-zero'
# When the length of num_grids, scale_ranges, and num_levels are not equal.
with pytest.raises(AssertionError):
DecoupledSOLOHead(
num_classes=4,
in_channels=1,
num_grids=[36, 24, 16, 12],
loss_mask=dict(
type='DiceLoss',
use_sigmoid=True,
activate=False,
loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0))
# When input feature length is not equal to num_levels.
with pytest.raises(AssertionError):
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32]
]
self.forward(feat)
def test_desolo_light_head_loss():
"""Tests solo head loss when truth is empty and non-empty."""
s = 256
img_metas = [{
'img_shape': (s, s, 3),
'scale_factor': 1,
'pad_shape': (s, s, 3)
}]
self = DecoupledSOLOLightHead(
num_classes=4,
in_channels=1,
num_grids=[40, 36, 24, 16, 12],
loss_mask=dict(
type='DiceLoss', use_sigmoid=True, activate=False,
loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0))
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32, 64]
]
mask_preds_x, mask_preds_y, cls_preds = self.forward(feat)
# Test that empty ground truth encourages the network to
# predict background.
gt_bboxes = [torch.empty((0, 4))]
gt_labels = [torch.LongTensor([])]
gt_masks = [torch.empty((0, 550, 550))]
gt_bboxes_ignore = None
empty_gt_losses = self.loss(
mask_preds_x,
mask_preds_y,
cls_preds,
gt_labels,
gt_masks,
img_metas,
gt_bboxes,
gt_bboxes_ignore=gt_bboxes_ignore)
# When there is no truth, the cls loss should be nonzero but there should
# be no box loss.
empty_mask_loss = empty_gt_losses['loss_mask']
empty_cls_loss = empty_gt_losses['loss_cls']
assert empty_cls_loss.item() > 0, 'cls loss should be non-zero'
assert empty_mask_loss.item() == 0, (
'there should be no mask loss when there are no true masks')
# When truth is non-empty then both cls and box loss should be nonzero for
# random inputs.
gt_bboxes = [
torch.Tensor([[23.6667, 23.8757, 238.6326, 151.8874]]),
]
gt_labels = [torch.LongTensor([2])]
gt_masks = [(torch.rand((1, 256, 256)) > 0.5).float()]
one_gt_losses = self.loss(
mask_preds_x,
mask_preds_y,
cls_preds,
gt_labels,
gt_masks,
img_metas,
gt_bboxes,
gt_bboxes_ignore=gt_bboxes_ignore)
onegt_mask_loss = one_gt_losses['loss_mask']
onegt_cls_loss = one_gt_losses['loss_cls']
assert onegt_cls_loss.item() > 0, 'cls loss should be non-zero'
assert onegt_mask_loss.item() > 0, 'mask loss should be non-zero'
# When the length of num_grids, scale_ranges, and num_levels are not equal.
with pytest.raises(AssertionError):
DecoupledSOLOLightHead(
num_classes=4,
in_channels=1,
num_grids=[36, 24, 16, 12],
loss_mask=dict(type='DiceLoss', use_sigmoid=True, loss_weight=3.0),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0))
# When input feature length is not equal to num_levels.
with pytest.raises(AssertionError):
feat = [
torch.rand(1, 1, s // feat_size, s // feat_size)
for feat_size in [4, 8, 16, 32]
]
self.forward(feat)

@ -9,6 +9,7 @@ from mmdet.models.losses import (BalancedL1Loss, CrossEntropyLoss,
KnowledgeDistillationKLDivLoss, L1Loss,
MSELoss, QualityFocalLoss, SeesawLoss,
SmoothL1Loss, VarifocalLoss)
from mmdet.models.losses.dice_loss import DiceLoss
from mmdet.models.losses.ghm_loss import GHMC, GHMR
from mmdet.models.losses.iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss,
GIoULoss, IoULoss)
@ -29,7 +30,7 @@ def test_iou_type_loss_zeros_weight(loss_class):
BalancedL1Loss, BoundedIoULoss, CIoULoss, CrossEntropyLoss, DIoULoss,
FocalLoss, DistributionFocalLoss, MSELoss, SeesawLoss, GaussianFocalLoss,
GIoULoss, IoULoss, L1Loss, QualityFocalLoss, VarifocalLoss, GHMR, GHMC,
SmoothL1Loss, KnowledgeDistillationKLDivLoss
SmoothL1Loss, KnowledgeDistillationKLDivLoss, DiceLoss
])
def test_loss_with_reduction_override(loss_class):
pred = torch.rand((10, 4))
@ -163,3 +164,53 @@ def test_loss_with_ignore_index(use_sigmoid):
assert torch.allclose(loss, loss_with_ignore)
assert torch.allclose(loss, loss_with_forward_ignore)
def test_dice_loss():
loss_class = DiceLoss
pred = torch.rand((10, 4, 4))
target = torch.rand((10, 4, 4))
weight = torch.rand((10))
# Test loss forward
loss = loss_class()(pred, target)
assert isinstance(loss, torch.Tensor)
# Test loss forward with weight
loss = loss_class()(pred, target, weight)
assert isinstance(loss, torch.Tensor)
# Test loss forward with reduction_override
loss = loss_class()(pred, target, reduction_override='mean')
assert isinstance(loss, torch.Tensor)
# Test loss forward with avg_factor
loss = loss_class()(pred, target, avg_factor=10)
assert isinstance(loss, torch.Tensor)
with pytest.raises(ValueError):
# loss can evaluate with avg_factor only if
# reduction is None, 'none' or 'mean'.
reduction_override = 'sum'
loss_class()(
pred, target, avg_factor=10, reduction_override=reduction_override)
# Test loss forward with avg_factor and reduction
for reduction_override in [None, 'none', 'mean']:
loss_class()(
pred, target, avg_factor=10, reduction_override=reduction_override)
assert isinstance(loss, torch.Tensor)
# Test loss forward with has_acted=False and use_sigmoid=False
with pytest.raises(NotImplementedError):
loss_class(use_sigmoid=False, activate=True)(pred, target)
# Test loss forward with weight.ndim != loss.ndim
with pytest.raises(AssertionError):
weight = torch.rand((2, 8))
loss_class()(pred, target, weight)
# Test loss forward with len(weight) != len(pred)
with pytest.raises(AssertionError):
weight = torch.rand((8))
loss_class()(pred, target, weight)

@ -5,7 +5,7 @@ import torch
from mmdet.core.bbox import distance2bbox
from mmdet.core.mask.structures import BitmapMasks, PolygonMasks
from mmdet.core.utils import mask2ndarray
from mmdet.core.utils import center_of_mass, mask2ndarray
def dummy_raw_polygon_masks(size):
@ -91,3 +91,20 @@ def test_distance2bbox():
deltas = torch.zeros((2, 0, 4))
out = distance2bbox(rois, deltas, max_shape=(120, 100))
assert rois.shape == out.shape
@pytest.mark.parametrize('mask', [
torch.ones((28, 28)),
torch.zeros((28, 28)),
torch.rand(28, 28) > 0.5,
torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
])
def test_center_of_mass(mask):
center_h, center_w = center_of_mass(mask)
if mask.shape[0] == 4:
assert center_h == 1.5
assert center_w == 1.5
assert isinstance(center_h, torch.Tensor) \
and isinstance(center_w, torch.Tensor)
assert 0 <= center_h <= 28 \
and 0 <= center_w <= 28

@ -0,0 +1,75 @@
import pytest
import torch
from mmdet.core.post_processing import mask_matrix_nms
def _create_mask(N, h, w):
masks = torch.rand((N, h, w)) > 0.5
labels = torch.rand(N)
scores = torch.rand(N)
return masks, labels, scores
def test_nms_input_errors():
with pytest.raises(AssertionError):
mask_matrix_nms(
torch.rand((10, 28, 28)), torch.rand(11), torch.rand(11))
with pytest.raises(AssertionError):
masks = torch.rand((10, 28, 28))
mask_matrix_nms(
masks,
torch.rand(11),
torch.rand(11),
mask_area=masks.sum((1, 2)).float()[:8])
with pytest.raises(NotImplementedError):
mask_matrix_nms(
torch.rand((10, 28, 28)),
torch.rand(10),
torch.rand(10),
kernel='None')
# test an empty results
masks, labels, scores = _create_mask(0, 28, 28)
score, label, mask, keep_ind = \
mask_matrix_nms(masks, labels, scores)
assert len(score) == len(label) == \
len(mask) == len(keep_ind) == 0
# do not use update_thr, nms_pre and max_num
masks, labels, scores = _create_mask(1000, 28, 28)
score, label, mask, keep_ind = \
mask_matrix_nms(masks, labels, scores)
assert len(score) == len(label) == \
len(mask) == len(keep_ind) == 1000
# only use nms_pre
score, label, mask, keep_ind = \
mask_matrix_nms(masks, labels, scores, nms_pre=500)
assert len(score) == len(label) == \
len(mask) == len(keep_ind) == 500
# use max_num
score, label, mask, keep_ind = \
mask_matrix_nms(masks, labels, scores,
nms_pre=500, max_num=100)
assert len(score) == len(label) == \
len(mask) == len(keep_ind) == 100
masks, labels, _ = _create_mask(1, 28, 28)
scores = torch.Tensor([1.0])
masks = masks.expand(1000, 28, 28)
labels = labels.expand(1000)
scores = scores.expand(1000)
# assert scores is decayed and update_thr is worked
# if with the same mask, label, and all scores = 1
# the first score will set to 1, others will decay.
score, label, mask, keep_ind = \
mask_matrix_nms(masks,
labels,
scores,
nms_pre=500,
max_num=100,
kernel='gaussian',
sigma=2.0,
filter_thr=0.5)
assert len(score) == 1
assert score[0] == 1
Loading…
Cancel
Save