Code for the AAAI 2021 paper "SCNet: Training Inference Sample Consistency for Instance Segmentation" (#4356)

* add scnet

* add more configs and minors

* fix unittest for scnet

* update docstring

* add forward test for htc and scnet

* fix build on pytorch 1.3 for empty tensor

* update doc string and minor refactor

* update unit-test and use inheritance for SCNet heads

* support scnet tta

* minor

* update readme

* minor fixes
Authored by Thang Vu, committed via GitHub
parent 31809ece87
commit 40f168937d
24 changed files (lines changed in parentheses):

- README.md (1)
- configs/scnet/README.md (51)
- configs/scnet/scnet_r101_fpn_20e_coco.py (2)
- configs/scnet/scnet_r50_fpn_1x_coco.py (136)
- configs/scnet/scnet_r50_fpn_20e_coco.py (4)
- configs/scnet/scnet_x101_64x4d_fpn_20e_coco.py (14)
- configs/scnet/scnet_x101_64x4d_fpn_8x1_20e_coco.py (3)
- mmdet/models/detectors/__init__.py (4)
- mmdet/models/detectors/scnet.py (10)
- mmdet/models/roi_heads/__init__.py (14)
- mmdet/models/roi_heads/bbox_heads/__init__.py (4)
- mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py (76)
- mmdet/models/roi_heads/mask_heads/__init__.py (7)
- mmdet/models/roi_heads/mask_heads/feature_relay_head.py (55)
- mmdet/models/roi_heads/mask_heads/global_context_head.py (101)
- mmdet/models/roi_heads/mask_heads/scnet_mask_head.py (27)
- mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py (27)
- mmdet/models/roi_heads/scnet_roi_head.py (582)
- mmdet/models/utils/__init__.py (4)
- mmdet/models/utils/res_layer.py (85)
- tests/test_config.py (5)
- tests/test_data/test_models_aug_test.py (8)
- tests/test_models/test_backbones.py (58)
- tests/test_models/test_forward.py (54)

@ -108,6 +108,7 @@ Supported methods:
- [x] [VFNet](configs/vfnet/README.md)
- [x] [DETR](configs/detr/README.md)
- [x] [CascadeRPN](configs/cascade_rpn/README.md)
- [x] [SCNet](configs/scnet/README.md)
Some other methods are also supported in [projects using MMDetection](./docs/projects.md).

@ -0,0 +1,51 @@
# SCNet
## Introduction
[ALGORITHM]
We provide code to reproduce the experimental results of [SCNet](https://arxiv.org/abs/2012.10150).
```
@inproceedings{vu2021scnet,
title={SCNet: Training Inference Sample Consistency for Instance Segmentation},
author={Vu, Thang and Kang, Haeyong and Yoo, Chang D.},
booktitle={AAAI},
year={2021}
}
```
## Dataset
SCNet requires the COCO and COCO-stuff datasets for training. Download and extract them under the COCO dataset path so that the directory structure looks like this:
```none
mmdetection
├── mmdet
├── tools
├── configs
├── data
│ ├── coco
│ │ ├── annotations
│ │ ├── train2017
│ │ ├── val2017
│ │ ├── test2017
│ │ ├── stuffthingmaps
```
## Results and Models
The results on COCO 2017 val are shown in the table below. (Results on test-dev are usually slightly higher than on val.)
| Backbone | Style | Lr schd | Mem (GB) | Inf speed (fps) | box AP | mask AP | TTA box AP | TTA mask AP | Config | Download |
|:---------------:|:-------:|:-------:|:--------:|:---------------:|:------:|:-------:|:----------:|:-----------:|:------:|:------------:|
| R-50-FPN | pytorch | 1x | 7.0 | 6.2 | 43.5 | 39.2 | 44.8 | 40.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_1x_coco.py) | [model](https://drive.google.com/file/d/179pcG-sNVDglJoZcQLsM8GWZnJhZtgWx/view?usp=sharing) \| [log](https://drive.google.com/file/d/1ZFS6QhFfxlOnDYPiGpSDP_Fzgb7iDGN3/view?usp=sharing) |
| R-50-FPN | pytorch | 20e | 7.0 | 6.2 | 44.5 | 40.0 | 45.8 | 41.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1hCH4raiUXsnYQYmKc4ADeL5pNc2X72fJ/view?usp=sharing) \| [log](https://drive.google.com/file/d/1-LnkOXN8n5ojQW34H0qZ625cgrnWpqSX/view?usp=sharing) |
| R-101-FPN | pytorch | 20e | 8.9 | 5.8 | 45.8 | 40.9 | 47.3 | 42.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r101_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1LcnegOX9YYZ7G8hqUa1F_Mp6PeT6_3jA/view?usp=sharing) \| [log](https://drive.google.com/file/d/1iRx-9GRgTaIDsz-we3DGwFVH22nbvCLa/view?usp=sharing) |
| X-101-64x4d-FPN | pytorch | 20e | 13.2 | 4.9 | 47.5 | 42.3 | 48.9 | 44.0 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_x101_64x4d_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1MDIoOBKwfXUdtEduz2BQm3MSgC05IlqZ/view?usp=sharing) \| [log](https://drive.google.com/file/d/1OsfQJ8gwtqIQ61k358yxY21sCvbUcRjs/view?usp=sharing) |
### Notes
- Training hyper-parameters are identical to those of [HTC](https://github.com/open-mmlab/mmdetection/tree/master/configs/htc).
- TTA means Test-Time Augmentation, which applies horizontal flipping and multi-scale testing. Refer to the commented-out test pipeline in the [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_1x_coco.py).
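
For quick verification, inference follows the standard MMDetection workflow. The sketch below is illustrative (not part of this commit); the checkpoint path is a placeholder for one of the models linked above.

```python
# Minimal inference sketch using the standard MMDetection high-level APIs.
from mmdet.apis import inference_detector, init_detector

config_file = 'configs/scnet/scnet_r50_fpn_1x_coco.py'
checkpoint_file = 'checkpoints/scnet_r50_fpn_1x_coco.pth'  # placeholder path
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# For an instance segmentation model like SCNet, the result is a
# (bbox_results, segm_results) pair.
result = inference_detector(model, 'demo/demo.jpg')
model.show_result('demo/demo.jpg', result, out_file='scnet_demo.jpg')
```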

@ -0,0 +1,2 @@
_base_ = './scnet_r50_fpn_20e_coco.py'
model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101))

@ -0,0 +1,136 @@
_base_ = '../htc/htc_r50_fpn_1x_coco.py'
# model settings
model = dict(
type='SCNet',
roi_head=dict(
_delete_=True,
type='SCNetRoIHead',
num_stages=3,
stage_loss_weights=[1, 0.5, 0.25],
bbox_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
bbox_head=[
dict(
type='SCNetBBoxHead',
num_shared_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.1, 0.1, 0.2, 0.2]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)),
dict(
type='SCNetBBoxHead',
num_shared_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.05, 0.05, 0.1, 0.1]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
loss_weight=1.0)),
dict(
type='SCNetBBoxHead',
num_shared_fcs=2,
in_channels=256,
fc_out_channels=1024,
roi_feat_size=7,
num_classes=80,
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[0., 0., 0., 0.],
target_stds=[0.033, 0.033, 0.067, 0.067]),
reg_class_agnostic=True,
loss_cls=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
],
mask_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
out_channels=256,
featmap_strides=[4, 8, 16, 32]),
mask_head=dict(
type='SCNetMaskHead',
num_convs=12,
in_channels=256,
conv_out_channels=256,
num_classes=80,
conv_to_res=True,
loss_mask=dict(
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
semantic_roi_extractor=dict(
type='SingleRoIExtractor',
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
out_channels=256,
featmap_strides=[8]),
semantic_head=dict(
type='SCNetSemanticHead',
num_ins=5,
fusion_level=1,
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=183,
ignore_label=255,
loss_weight=0.2,
conv_to_res=True),
glbctx_head=dict(
type='GlobalContextHead',
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=81,
loss_weight=3.0,
conv_to_res=True),
feat_relay_head=dict(
type='FeatureRelayHead',
in_channels=1024,
out_conv_channels=256,
roi_feat_size=7,
scale_factor=2)))
# uncomment the code below to enable test-time augmentation
# img_norm_cfg = dict(
# mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# test_pipeline = [
# dict(type='LoadImageFromFile'),
# dict(
# type='MultiScaleFlipAug',
# img_scale=[(600, 900), (800, 1200), (1000, 1500), (1200, 1800),
# (1400, 2100)],
# flip=True,
# transforms=[
# dict(type='Resize', keep_ratio=True),
# dict(type='RandomFlip', flip_ratio=0.5),
# dict(type='Normalize', **img_norm_cfg),
# dict(type='Pad', size_divisor=32),
# dict(type='ImageToTensor', keys=['img']),
# dict(type='Collect', keys=['img']),
# ])
# ]
# data = dict(
# val=dict(pipeline=test_pipeline),
# test=dict(pipeline=test_pipeline))
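
As a quick sanity check, the detector described by this config can be built with the usual MMCV/MMDetection utilities; a minimal sketch (not part of this commit), assuming MMDetection 2.x with top-level `train_cfg`/`test_cfg`:

```python
# Build the SCNet detector from this config (sketch, MMDetection 2.x assumed).
from mmcv import Config

from mmdet.models import build_detector

cfg = Config.fromfile('configs/scnet/scnet_r50_fpn_1x_coco.py')
model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
print(type(model).__name__)  # SCNet
```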

@ -0,0 +1,4 @@
_base_ = './scnet_r50_fpn_1x_coco.py'
# learning policy
lr_config = dict(step=[16, 19])
total_epochs = 20

@ -0,0 +1,14 @@
_base_ = './scnet_r50_fpn_20e_coco.py'
model = dict(
pretrained='open-mmlab://resnext101_64x4d',
backbone=dict(
type='ResNeXt',
depth=101,
groups=64,
base_width=4,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'))

@ -0,0 +1,3 @@
_base_ = './scnet_x101_64x4d_fpn_20e_coco.py'
data = dict(samples_per_gpu=1, workers_per_gpu=1)
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

@ -19,6 +19,7 @@ from .point_rend import PointRend
from .reppoints_detector import RepPointsDetector
from .retinanet import RetinaNet
from .rpn import RPN
from .scnet import SCNet
from .single_stage import SingleStageDetector
from .sparse_rcnn import SparseRCNN
from .trident_faster_rcnn import TridentFasterRCNN
@ -32,5 +33,6 @@ __all__ = [
'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN'
'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN',
'SCNet'
]

@ -0,0 +1,10 @@
from ..builder import DETECTORS
from .cascade_rcnn import CascadeRCNN
@DETECTORS.register_module()
class SCNet(CascadeRCNN):
"""Implementation of `SCNet <https://arxiv.org/abs/2012.10150>`_"""
def __init__(self, **kwargs):
super(SCNet, self).__init__(**kwargs)

@ -1,17 +1,21 @@
from .base_roi_head import BaseRoIHead
from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DoubleConvFCBBoxHead,
Shared2FCBBoxHead, Shared4Conv1FCBBoxHead)
SCNetBBoxHead, Shared2FCBBoxHead,
Shared4Conv1FCBBoxHead)
from .cascade_roi_head import CascadeRoIHead
from .double_roi_head import DoubleHeadRoIHead
from .dynamic_roi_head import DynamicRoIHead
from .grid_roi_head import GridRoIHead
from .htc_roi_head import HybridTaskCascadeRoIHead
from .mask_heads import (CoarseMaskHead, FCNMaskHead, FusedSemanticHead,
GridHead, HTCMaskHead, MaskIoUHead, MaskPointHead)
from .mask_heads import (CoarseMaskHead, FCNMaskHead, FeatureRelayHead,
FusedSemanticHead, GlobalContextHead, GridHead,
HTCMaskHead, MaskIoUHead, MaskPointHead,
SCNetMaskHead, SCNetSemanticHead)
from .mask_scoring_roi_head import MaskScoringRoIHead
from .pisa_roi_head import PISARoIHead
from .point_rend_roi_head import PointRendRoIHead
from .roi_extractors import SingleRoIExtractor
from .scnet_roi_head import SCNetRoIHead
from .shared_heads import ResLayer
from .sparse_roi_head import SparseRoIHead
from .standard_roi_head import StandardRoIHead
@ -24,5 +28,7 @@ __all__ = [
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'FCNMaskHead',
'HTCMaskHead', 'FusedSemanticHead', 'GridHead', 'MaskIoUHead',
'SingleRoIExtractor', 'PISARoIHead', 'PointRendRoIHead', 'MaskPointHead',
'CoarseMaskHead', 'DynamicRoIHead', 'SparseRoIHead', 'TridentRoIHead'
'CoarseMaskHead', 'DynamicRoIHead', 'SparseRoIHead', 'TridentRoIHead',
'SCNetRoIHead', 'SCNetMaskHead', 'SCNetSemanticHead', 'SCNetBBoxHead',
'FeatureRelayHead', 'GlobalContextHead'
]

@ -4,8 +4,10 @@ from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead,
from .dii_head import DIIHead
from .double_bbox_head import DoubleConvFCBBoxHead
from .sabl_head import SABLHead
from .scnet_bbox_head import SCNetBBoxHead
__all__ = [
'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead', 'DIIHead'
'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead', 'DIIHead',
'SCNetBBoxHead'
]

@ -0,0 +1,76 @@
from mmdet.models.builder import HEADS
from .convfc_bbox_head import ConvFCBBoxHead
@HEADS.register_module()
class SCNetBBoxHead(ConvFCBBoxHead):
"""BBox head for `SCNet <https://arxiv.org/abs/2012.10150>`_.
This inherits ``ConvFCBBoxHead`` with a modified ``forward()`` that can
also return the intermediate shared feature.
"""
def _forward_shared(self, x):
"""Forward function for shared part."""
if self.num_shared_convs > 0:
for conv in self.shared_convs:
x = conv(x)
if self.num_shared_fcs > 0:
if self.with_avg_pool:
x = self.avg_pool(x)
x = x.flatten(1)
for fc in self.shared_fcs:
x = self.relu(fc(x))
return x
def _forward_cls_reg(self, x):
"""Forward function for classification and regression parts."""
x_cls = x
x_reg = x
for conv in self.cls_convs:
x_cls = conv(x_cls)
if x_cls.dim() > 2:
if self.with_avg_pool:
x_cls = self.avg_pool(x_cls)
x_cls = x_cls.flatten(1)
for fc in self.cls_fcs:
x_cls = self.relu(fc(x_cls))
for conv in self.reg_convs:
x_reg = conv(x_reg)
if x_reg.dim() > 2:
if self.with_avg_pool:
x_reg = self.avg_pool(x_reg)
x_reg = x_reg.flatten(1)
for fc in self.reg_fcs:
x_reg = self.relu(fc(x_reg))
cls_score = self.fc_cls(x_cls) if self.with_cls else None
bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
return cls_score, bbox_pred
def forward(self, x, return_shared_feat=False):
"""Forward function.
Args:
x (Tensor): input features
return_shared_feat (bool): If True, also return the feature shared
by the cls and reg branches.
Returns:
out (tuple[Tensor]): Contains ``cls_score`` and ``bbox_pred``. If
``return_shared_feat`` is True, ``x_shared`` is appended to the
returned tuple.
"""
x_shared = self._forward_shared(x)
out = self._forward_cls_reg(x_shared)
if return_shared_feat:
out += (x_shared, )
return out
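
A minimal usage sketch for this head (not part of this commit), assuming random RoI features and the default bbox coder and loss settings:

```python
# Sketch: run SCNetBBoxHead and also fetch the shared feature for relaying.
import torch

from mmdet.models.roi_heads.bbox_heads import SCNetBBoxHead

head = SCNetBBoxHead(
    num_shared_fcs=2,
    in_channels=256,
    fc_out_channels=1024,
    roi_feat_size=7,
    num_classes=80)
roi_feats = torch.randn(8, 256, 7, 7)  # features for 8 RoIs
cls_score, bbox_pred, x_shared = head(roi_feats, return_shared_feat=True)
assert cls_score.shape == (8, 81)   # 80 classes + background
assert x_shared.shape == (8, 1024)  # relayed to the mask branch
```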

@ -1,12 +1,17 @@
from .coarse_mask_head import CoarseMaskHead
from .fcn_mask_head import FCNMaskHead
from .feature_relay_head import FeatureRelayHead
from .fused_semantic_head import FusedSemanticHead
from .global_context_head import GlobalContextHead
from .grid_head import GridHead
from .htc_mask_head import HTCMaskHead
from .mask_point_head import MaskPointHead
from .maskiou_head import MaskIoUHead
from .scnet_mask_head import SCNetMaskHead
from .scnet_semantic_head import SCNetSemanticHead
__all__ = [
'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead'
'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead', 'SCNetMaskHead',
'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead'
]

@ -0,0 +1,55 @@
import torch.nn as nn
from mmcv.cnn import kaiming_init
from mmcv.runner import auto_fp16
from mmdet.models.builder import HEADS
@HEADS.register_module()
class FeatureRelayHead(nn.Module):
"""Feature Relay Head used in SCNet https://arxiv.org/abs/2012.10150.
Args:
in_channels (int, optional): number of input channels. Default: 1024.
out_conv_channels (int, optional): number of channels of the output
spatial feature map. Default: 256.
roi_feat_size (int, optional): roi feat size at box head. Default: 7.
scale_factor (int, optional): scale factor to match roi feat size
at mask head. Default: 2.
"""
def __init__(self,
in_channels=1024,
out_conv_channels=256,
roi_feat_size=7,
scale_factor=2):
super(FeatureRelayHead, self).__init__()
assert isinstance(roi_feat_size, int)
self.in_channels = in_channels
self.out_conv_channels = out_conv_channels
self.roi_feat_size = roi_feat_size
self.out_channels = (roi_feat_size**2) * out_conv_channels
self.scale_factor = scale_factor
self.fp16_enabled = False
self.fc = nn.Linear(self.in_channels, self.out_channels)
self.upsample = nn.Upsample(
scale_factor=scale_factor, mode='bilinear', align_corners=True)
def init_weights(self):
"""Init weights for the head."""
kaiming_init(self.fc)
@auto_fp16()
def forward(self, x):
"""Forward function."""
N, in_C = x.shape
if N > 0:
out_C = self.out_conv_channels
out_HW = self.roi_feat_size
x = self.fc(x)
x = x.reshape(N, out_C, out_HW, out_HW)
x = self.upsample(x)
return x
return None
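
A quick shape check for this head (a sketch, not part of this commit):

```python
# Sketch: FeatureRelayHead maps flat box-head features back to a spatial map.
import torch

from mmdet.models.roi_heads.mask_heads import FeatureRelayHead

head = FeatureRelayHead(
    in_channels=1024, out_conv_channels=256, roi_feat_size=7, scale_factor=2)
x = torch.randn(4, 1024)  # shared fc features for 4 positive RoIs
out = head(x)
assert out.shape == (4, 256, 14, 14)  # matches the 14x14 mask RoI features
assert head(torch.empty(0, 1024)) is None  # empty input short-circuits
```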

@ -0,0 +1,101 @@
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16, force_fp32
from mmdet.models.builder import HEADS
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
@HEADS.register_module()
class GlobalContextHead(nn.Module):
"""Global context head used in SCNet https://arxiv.org/abs/2012.10150.
Args:
num_convs (int, optional): number of convolutional layers in the
global context head. Default: 4.
in_channels (int, optional): number of input channels. Default: 256.
conv_out_channels (int, optional): number of output channels before
classification layer. Default: 256.
num_classes (int, optional): number of classes. Default: 81.
loss_weight (float, optional): global context loss weight. Default: 1.
conv_cfg (dict, optional): config to init conv layer. Default: None.
norm_cfg (dict, optional): config to init norm layer. Default: None.
conv_to_res (bool, optional): if True, 2 convs will be grouped into
1 `SimplifiedBasicBlock` using a skip connection. Default: False.
"""
def __init__(self,
num_convs=4,
in_channels=256,
conv_out_channels=256,
num_classes=81,
loss_weight=1.0,
conv_cfg=None,
norm_cfg=None,
conv_to_res=False):
super(GlobalContextHead, self).__init__()
self.num_convs = num_convs
self.in_channels = in_channels
self.conv_out_channels = conv_out_channels
self.num_classes = num_classes
self.loss_weight = loss_weight
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.conv_to_res = conv_to_res
self.fp16_enabled = False
if self.conv_to_res:
num_res_blocks = num_convs // 2
self.convs = ResLayer(
SimplifiedBasicBlock,
in_channels,
self.conv_out_channels,
num_res_blocks,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg)
self.num_convs = num_res_blocks
else:
self.convs = nn.ModuleList()
for i in range(self.num_convs):
in_channels = self.in_channels if i == 0 else conv_out_channels
self.convs.append(
ConvModule(
in_channels,
conv_out_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg))
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(conv_out_channels, num_classes)
self.criterion = nn.BCEWithLogitsLoss()
def init_weights(self):
"""Init weights for the head."""
nn.init.normal_(self.fc.weight, 0, 0.01)
nn.init.constant_(self.fc.bias, 0)
@auto_fp16()
def forward(self, feats):
"""Forward function."""
x = feats[-1]
for i in range(self.num_convs):
x = self.convs[i](x)
x = self.pool(x)
# multi-class prediction
mc_pred = x.reshape(x.size(0), -1)
mc_pred = self.fc(mc_pred)
return mc_pred, x
@force_fp32(apply_to=('pred', ))
def loss(self, pred, labels):
"""Loss function."""
labels = [lbl.unique() for lbl in labels]
targets = pred.new_zeros(pred.size())
for i, label in enumerate(labels):
targets[i, label] = 1.0
loss = self.loss_weight * self.criterion(pred, targets)
return loss
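
The loss converts per-image instance labels into multi-label image-classification targets; a small sketch (not part of this commit):

```python
# Sketch: GlobalContextHead predicts which classes are present in each image.
import torch

from mmdet.models.roi_heads.mask_heads import GlobalContextHead

head = GlobalContextHead(num_convs=4, in_channels=256, num_classes=81)
feats = [torch.randn(1, 256, 16, 16) for _ in range(5)]  # FPN levels
mc_pred, glbctx_feat = head(feats)  # (1, 81) logits, (1, 256, 1, 1) context
gt_labels = [torch.tensor([2, 5, 5])]  # instance labels of the single image
loss = head.loss(mc_pred, gt_labels)  # BCE against the {2, 5} multi-hot target
```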

@ -0,0 +1,27 @@
from mmdet.models.builder import HEADS
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
from .fcn_mask_head import FCNMaskHead
@HEADS.register_module()
class SCNetMaskHead(FCNMaskHead):
"""Mask head for `SCNet <https://arxiv.org/abs/2012.10150>`_.
Args:
conv_to_res (bool, optional): if True, change the conv layers to
``SimplifiedBasicBlock``.
"""
def __init__(self, conv_to_res=True, **kwargs):
super(SCNetMaskHead, self).__init__(**kwargs)
self.conv_to_res = conv_to_res
if conv_to_res:
assert self.conv_kernel_size == 3
self.num_res_blocks = self.num_convs // 2
self.convs = ResLayer(
SimplifiedBasicBlock,
self.in_channels,
self.conv_out_channels,
self.num_res_blocks,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg)

@ -0,0 +1,27 @@
from mmdet.models.builder import HEADS
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
from .fused_semantic_head import FusedSemanticHead
@HEADS.register_module()
class SCNetSemanticHead(FusedSemanticHead):
"""Mask head for `SCNet <https://arxiv.org/abs/2012.10150>`_.
Args:
conv_to_res (bool, optional): if True, change the conv layers to
``SimplifiedBasicBlock``.
"""
def __init__(self, conv_to_res=True, **kwargs):
super(SCNetSemanticHead, self).__init__(**kwargs)
self.conv_to_res = conv_to_res
if self.conv_to_res:
num_res_blocks = self.num_convs // 2
self.convs = ResLayer(
SimplifiedBasicBlock,
self.in_channels,
self.conv_out_channels,
num_res_blocks,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg)
self.num_convs = num_res_blocks

@ -0,0 +1,582 @@
import torch
import torch.nn.functional as F
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
merge_aug_masks, multiclass_nms)
from ..builder import HEADS, build_head, build_roi_extractor
from .cascade_roi_head import CascadeRoIHead
@HEADS.register_module()
class SCNetRoIHead(CascadeRoIHead):
"""RoIHead for `SCNet <https://arxiv.org/abs/2012.10150>`_.
Args:
num_stages (int): number of cascade stages.
stage_loss_weights (list): loss weight of cascade stages.
semantic_roi_extractor (dict): config to init semantic roi extractor.
semantic_head (dict): config to init semantic head.
feat_relay_head (dict): config to init feature_relay_head.
glbctx_head (dict): config to init global context head.
"""
def __init__(self,
num_stages,
stage_loss_weights,
semantic_roi_extractor=None,
semantic_head=None,
feat_relay_head=None,
glbctx_head=None,
**kwargs):
super(SCNetRoIHead, self).__init__(num_stages, stage_loss_weights,
**kwargs)
assert self.with_bbox and self.with_mask
assert not self.with_shared_head # shared head is not supported
if semantic_head is not None:
self.semantic_roi_extractor = build_roi_extractor(
semantic_roi_extractor)
self.semantic_head = build_head(semantic_head)
if feat_relay_head is not None:
self.feat_relay_head = build_head(feat_relay_head)
if glbctx_head is not None:
self.glbctx_head = build_head(glbctx_head)
def init_mask_head(self, mask_roi_extractor, mask_head):
"""Initialize ``mask_head``"""
if mask_roi_extractor is not None:
self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
self.mask_head = build_head(mask_head)
def init_weights(self, pretrained):
"""Initialize the weights in head.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
for i in range(self.num_stages):
if self.with_bbox:
self.bbox_roi_extractor[i].init_weights()
self.bbox_head[i].init_weights()
if self.with_mask:
self.mask_roi_extractor.init_weights()
self.mask_head.init_weights()
if self.with_semantic:
self.semantic_head.init_weights()
if self.with_glbctx:
self.glbctx_head.init_weights()
if self.with_feat_relay:
self.feat_relay_head.init_weights()
@property
def with_semantic(self):
"""bool: whether the head has semantic head"""
return hasattr(self,
'semantic_head') and self.semantic_head is not None
@property
def with_feat_relay(self):
"""bool: whether the head has feature relay head"""
return (hasattr(self, 'feat_relay_head')
and self.feat_relay_head is not None)
@property
def with_glbctx(self):
"""bool: whether the head has global context head"""
return hasattr(self, 'glbctx_head') and self.glbctx_head is not None
def _fuse_glbctx(self, roi_feats, glbctx_feat, rois):
"""Fuse global context feats with roi feats."""
assert roi_feats.size(0) == rois.size(0)
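# rois[:, 0] stores the batch image index of each roi; add the
# per-image global context vector to every roi from that image.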
img_inds = torch.unique(rois[:, 0].cpu(), sorted=True).long()
fused_feats = torch.zeros_like(roi_feats)
for img_id in img_inds:
inds = (rois[:, 0] == img_id.item())
fused_feats[inds] = roi_feats[inds] + glbctx_feat[img_id]
return fused_feats
def _slice_pos_feats(self, feats, sampling_results):
"""Get features from pos rois."""
num_rois = [res.bboxes.size(0) for res in sampling_results]
num_pos_rois = [res.pos_bboxes.size(0) for res in sampling_results]
inds = torch.zeros(sum(num_rois), dtype=torch.bool)
start = 0
for i in range(len(num_rois)):
start = 0 if i == 0 else start + num_rois[i - 1]
stop = start + num_pos_rois[i]
inds[start:stop] = 1
sliced_feats = feats[inds]
return sliced_feats
def _bbox_forward(self,
stage,
x,
rois,
semantic_feat=None,
glbctx_feat=None):
"""Box head forward function used in both training and testing."""
bbox_roi_extractor = self.bbox_roi_extractor[stage]
bbox_head = self.bbox_head[stage]
bbox_feats = bbox_roi_extractor(
x[:len(bbox_roi_extractor.featmap_strides)], rois)
if self.with_semantic and semantic_feat is not None:
bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
rois)
if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
bbox_semantic_feat = F.adaptive_avg_pool2d(
bbox_semantic_feat, bbox_feats.shape[-2:])
bbox_feats += bbox_semantic_feat
if self.with_glbctx and glbctx_feat is not None:
bbox_feats = self._fuse_glbctx(bbox_feats, glbctx_feat, rois)
cls_score, bbox_pred, relayed_feat = bbox_head(
bbox_feats, return_shared_feat=True)
bbox_results = dict(
cls_score=cls_score,
bbox_pred=bbox_pred,
relayed_feat=relayed_feat)
return bbox_results
def _mask_forward(self,
x,
rois,
semantic_feat=None,
glbctx_feat=None,
relayed_feat=None):
"""Mask head forward function used in both training and testing."""
mask_feats = self.mask_roi_extractor(
x[:self.mask_roi_extractor.num_inputs], rois)
if self.with_semantic and semantic_feat is not None:
mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
rois)
if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
mask_semantic_feat = F.adaptive_avg_pool2d(
mask_semantic_feat, mask_feats.shape[-2:])
mask_feats += mask_semantic_feat
if self.with_glbctx and glbctx_feat is not None:
mask_feats = self._fuse_glbctx(mask_feats, glbctx_feat, rois)
if self.with_feat_relay and relayed_feat is not None:
mask_feats = mask_feats + relayed_feat
mask_pred = self.mask_head(mask_feats)
mask_results = dict(mask_pred=mask_pred)
return mask_results
def _bbox_forward_train(self,
stage,
x,
sampling_results,
gt_bboxes,
gt_labels,
rcnn_train_cfg,
semantic_feat=None,
glbctx_feat=None):
"""Run forward function and calculate loss for box head in training."""
bbox_head = self.bbox_head[stage]
rois = bbox2roi([res.bboxes for res in sampling_results])
bbox_results = self._bbox_forward(
stage,
x,
rois,
semantic_feat=semantic_feat,
glbctx_feat=glbctx_feat)
bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes,
gt_labels, rcnn_train_cfg)
loss_bbox = bbox_head.loss(bbox_results['cls_score'],
bbox_results['bbox_pred'], rois,
*bbox_targets)
bbox_results.update(
loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
return bbox_results
def _mask_forward_train(self,
x,
sampling_results,
gt_masks,
rcnn_train_cfg,
semantic_feat=None,
glbctx_feat=None,
relayed_feat=None):
"""Run forward function and calculate loss for mask head in
training."""
pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
mask_results = self._mask_forward(
x,
pos_rois,
semantic_feat=semantic_feat,
glbctx_feat=glbctx_feat,
relayed_feat=relayed_feat)
mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
rcnn_train_cfg)
pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
loss_mask = self.mask_head.loss(mask_results['mask_pred'],
mask_targets, pos_labels)
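# mask_head.loss returns a dict keyed by 'loss_mask', so it can be
# used directly as the mask_results dict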
mask_results = loss_mask
return mask_results
def forward_train(self,
x,
img_metas,
proposal_list,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None,
gt_masks=None,
gt_semantic_seg=None):
"""
Args:
x (list[Tensor]): list of multi-level img features.
img_metas (list[dict]): list of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`mmdet/datasets/pipelines/formatting.py:Collect`.
proposal_list (list[Tensor]): list of region proposals.
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): class indices corresponding to each box.
gt_bboxes_ignore (None | list[Tensor]): specify which bounding
boxes can be ignored when computing the loss.
gt_masks (None | Tensor): true segmentation masks for each box,
used if the architecture supports a segmentation task.
gt_semantic_seg (None | list[Tensor]): semantic segmentation masks,
used if the architecture supports a semantic segmentation task.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
losses = dict()
# semantic segmentation branch
if self.with_semantic:
semantic_pred, semantic_feat = self.semantic_head(x)
loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
losses['loss_semantic_seg'] = loss_seg
else:
semantic_feat = None
# global context branch
if self.with_glbctx:
mc_pred, glbctx_feat = self.glbctx_head(x)
loss_glbctx = self.glbctx_head.loss(mc_pred, gt_labels)
losses['loss_glbctx'] = loss_glbctx
else:
glbctx_feat = None
for i in range(self.num_stages):
self.current_stage = i
rcnn_train_cfg = self.train_cfg[i]
lw = self.stage_loss_weights[i]
# assign gts and sample proposals
sampling_results = []
bbox_assigner = self.bbox_assigner[i]
bbox_sampler = self.bbox_sampler[i]
num_imgs = len(img_metas)
if gt_bboxes_ignore is None:
gt_bboxes_ignore = [None for _ in range(num_imgs)]
for j in range(num_imgs):
assign_result = bbox_assigner.assign(proposal_list[j],
gt_bboxes[j],
gt_bboxes_ignore[j],
gt_labels[j])
sampling_result = bbox_sampler.sample(
assign_result,
proposal_list[j],
gt_bboxes[j],
gt_labels[j],
feats=[lvl_feat[j][None] for lvl_feat in x])
sampling_results.append(sampling_result)
bbox_results = \
self._bbox_forward_train(
i, x, sampling_results, gt_bboxes, gt_labels,
rcnn_train_cfg, semantic_feat, glbctx_feat)
roi_labels = bbox_results['bbox_targets'][0]
for name, value in bbox_results['loss_bbox'].items():
losses[f's{i}.{name}'] = (
value * lw if 'loss' in name else value)
# refine boxes
if i < self.num_stages - 1:
pos_is_gts = [res.pos_is_gt for res in sampling_results]
with torch.no_grad():
proposal_list = self.bbox_head[i].refine_bboxes(
bbox_results['rois'], roi_labels,
bbox_results['bbox_pred'], pos_is_gts, img_metas)
if self.with_feat_relay:
relayed_feat = self._slice_pos_feats(bbox_results['relayed_feat'],
sampling_results)
relayed_feat = self.feat_relay_head(relayed_feat)
else:
relayed_feat = None
mask_results = self._mask_forward_train(x, sampling_results, gt_masks,
rcnn_train_cfg, semantic_feat,
glbctx_feat, relayed_feat)
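# the mask branch runs only once after the cascade, so its loss is
# weighted by the sum of all stage loss weights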
mask_lw = sum(self.stage_loss_weights)
losses['loss_mask'] = mask_lw * mask_results['loss_mask']
return losses
def simple_test(self, x, proposal_list, img_metas, rescale=False):
"""Test without augmentation."""
if self.with_semantic:
_, semantic_feat = self.semantic_head(x)
else:
semantic_feat = None
if self.with_glbctx:
mc_pred, glbctx_feat = self.glbctx_head(x)
else:
glbctx_feat = None
num_imgs = len(proposal_list)
img_shapes = tuple(meta['img_shape'] for meta in img_metas)
ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
# "ms" in variable names means multi-stage
ms_scores = []
rcnn_test_cfg = self.test_cfg
rois = bbox2roi(proposal_list)
for i in range(self.num_stages):
bbox_head = self.bbox_head[i]
bbox_results = self._bbox_forward(
i,
x,
rois,
semantic_feat=semantic_feat,
glbctx_feat=glbctx_feat)
# split batch bbox prediction back to each image
cls_score = bbox_results['cls_score']
bbox_pred = bbox_results['bbox_pred']
num_proposals_per_img = tuple(len(p) for p in proposal_list)
rois = rois.split(num_proposals_per_img, 0)
cls_score = cls_score.split(num_proposals_per_img, 0)
bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
ms_scores.append(cls_score)
if i < self.num_stages - 1:
bbox_label = [s[:, :-1].argmax(dim=1) for s in cls_score]
rois = torch.cat([
bbox_head.regress_by_class(rois[i], bbox_label[i],
bbox_pred[i], img_metas[i])
for i in range(num_imgs)
])
# average scores of each image by stages
cls_score = [
sum([score[i] for score in ms_scores]) / float(len(ms_scores))
for i in range(num_imgs)
]
# apply bbox post-processing to each image individually
det_bboxes = []
det_labels = []
for i in range(num_imgs):
det_bbox, det_label = self.bbox_head[-1].get_bboxes(
rois[i],
cls_score[i],
bbox_pred[i],
img_shapes[i],
scale_factors[i],
rescale=rescale,
cfg=rcnn_test_cfg)
det_bboxes.append(det_bbox)
det_labels.append(det_label)
det_bbox_results = [
bbox2result(det_bboxes[i], det_labels[i],
self.bbox_head[-1].num_classes)
for i in range(num_imgs)
]
if self.with_mask:
if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
mask_classes = self.mask_head.num_classes
det_segm_results = [[[] for _ in range(mask_classes)]
for _ in range(num_imgs)]
else:
if rescale and not isinstance(scale_factors[0], float):
scale_factors = [
torch.from_numpy(scale_factor).to(det_bboxes[0].device)
for scale_factor in scale_factors
]
_bboxes = [
det_bboxes[i][:, :4] *
scale_factors[i] if rescale else det_bboxes[i]
for i in range(num_imgs)
]
mask_rois = bbox2roi(_bboxes)
# get relay feature on mask_rois
bbox_results = self._bbox_forward(
-1,
x,
mask_rois,
semantic_feat=semantic_feat,
glbctx_feat=glbctx_feat)
relayed_feat = bbox_results['relayed_feat']
relayed_feat = self.feat_relay_head(relayed_feat)
mask_results = self._mask_forward(
x,
mask_rois,
semantic_feat=semantic_feat,
glbctx_feat=glbctx_feat,
relayed_feat=relayed_feat)
mask_pred = mask_results['mask_pred']
# split batch mask prediction back to each image
num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes)
mask_preds = mask_pred.split(num_bbox_per_img, 0)
# apply mask post-processing to each image individually
det_segm_results = []
for i in range(num_imgs):
if det_bboxes[i].shape[0] == 0:
det_segm_results.append(
[[] for _ in range(self.mask_head.num_classes)])
else:
segm_result = self.mask_head.get_seg_masks(
mask_preds[i], _bboxes[i], det_labels[i],
self.test_cfg, ori_shapes[i], scale_factors[i],
rescale)
det_segm_results.append(segm_result)
# return results
if self.with_mask:
return list(zip(det_bbox_results, det_segm_results))
else:
return det_bbox_results
def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
if self.with_semantic:
semantic_feats = [
self.semantic_head(feat)[1] for feat in img_feats
]
else:
semantic_feats = [None] * len(img_metas)
if self.with_glbctx:
glbctx_feats = [self.glbctx_head(feat)[1] for feat in img_feats]
else:
glbctx_feats = [None] * len(img_metas)
rcnn_test_cfg = self.test_cfg
aug_bboxes = []
aug_scores = []
for x, img_meta, semantic_feat, glbctx_feat in zip(
img_feats, img_metas, semantic_feats, glbctx_feats):
# only one image in the batch
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
scale_factor, flip)
# "ms" in variable names means multi-stage
ms_scores = []
rois = bbox2roi([proposals])
for i in range(self.num_stages):
bbox_head = self.bbox_head[i]
bbox_results = self._bbox_forward(
i,
x,
rois,
semantic_feat=semantic_feat,
glbctx_feat=glbctx_feat)
ms_scores.append(bbox_results['cls_score'])
if i < self.num_stages - 1:
bbox_label = bbox_results['cls_score'].argmax(dim=1)
rois = bbox_head.regress_by_class(
rois, bbox_label, bbox_results['bbox_pred'],
img_meta[0])
cls_score = sum(ms_scores) / float(len(ms_scores))
bboxes, scores = self.bbox_head[-1].get_bboxes(
rois,
cls_score,
bbox_results['bbox_pred'],
img_shape,
scale_factor,
rescale=False,
cfg=None)
aug_bboxes.append(bboxes)
aug_scores.append(scores)
# after merging, bboxes will be rescaled to the original image size
merged_bboxes, merged_scores = merge_aug_bboxes(
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
rcnn_test_cfg.score_thr,
rcnn_test_cfg.nms,
rcnn_test_cfg.max_per_img)
det_bbox_results = bbox2result(det_bboxes, det_labels,
self.bbox_head[-1].num_classes)
if self.with_mask:
if det_bboxes.shape[0] == 0:
det_segm_results = [[]
for _ in range(self.mask_head.num_classes)]
else:
aug_masks = []
for x, img_meta, semantic_feat, glbctx_feat in zip(
img_feats, img_metas, semantic_feats, glbctx_feats):
img_shape = img_meta[0]['img_shape']
scale_factor = img_meta[0]['scale_factor']
flip = img_meta[0]['flip']
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
scale_factor, flip)
mask_rois = bbox2roi([_bboxes])
# get relay feature on mask_rois
bbox_results = self._bbox_forward(
-1,
x,
mask_rois,
semantic_feat=semantic_feat,
glbctx_feat=glbctx_feat)
relayed_feat = bbox_results['relayed_feat']
relayed_feat = self.feat_relay_head(relayed_feat)
mask_results = self._mask_forward(
x,
mask_rois,
semantic_feat=semantic_feat,
glbctx_feat=glbctx_feat,
relayed_feat=relayed_feat)
mask_pred = mask_results['mask_pred']
aug_masks.append(mask_pred.sigmoid().cpu().numpy())
merged_masks = merge_aug_masks(aug_masks, img_metas,
self.test_cfg)
ori_shape = img_metas[0][0]['ori_shape']
det_segm_results = self.mask_head.get_seg_masks(
merged_masks,
det_bboxes,
det_labels,
rcnn_test_cfg,
ori_shape,
scale_factor=1.0,
rescale=False)
return [(det_bbox_results, det_segm_results)]
else:
return [det_bbox_results]

@ -2,7 +2,7 @@ from .builder import build_positional_encoding, build_transformer
from .gaussian_target import gaussian_radius, gen_gaussian_target
from .positional_encoding import (LearnedPositionalEncoding,
SinePositionalEncoding)
from .res_layer import ResLayer
from .res_layer import ResLayer, SimplifiedBasicBlock
from .transformer import (FFN, DynamicConv, MultiheadAttention, Transformer,
TransformerDecoder, TransformerDecoderLayer,
TransformerEncoder, TransformerEncoderLayer)
@ -12,5 +12,5 @@ __all__ = [
'FFN', 'TransformerEncoderLayer', 'TransformerEncoder',
'TransformerDecoderLayer', 'TransformerDecoder', 'Transformer',
'build_transformer', 'build_positional_encoding', 'SinePositionalEncoding',
'LearnedPositionalEncoding', 'DynamicConv'
'LearnedPositionalEncoding', 'DynamicConv', 'SimplifiedBasicBlock'
]

@ -100,3 +100,88 @@ class ResLayer(nn.Sequential):
norm_cfg=norm_cfg,
**kwargs))
super(ResLayer, self).__init__(*layers)
class SimplifiedBasicBlock(nn.Module):
"""Simplified version of original basic residual block. This is used in
`SCNet <https://arxiv.org/abs/2012.10150>`_.
- Norm layer is now optional
- Last ReLU in forward function is removed
"""
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None):
super(SimplifiedBasicBlock, self).__init__()
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
assert not with_cp, 'Not implemented yet.'
self.with_norm = norm_cfg is not None
with_bias = norm_cfg is None
self.conv1 = build_conv_layer(
conv_cfg,
inplanes,
planes,
3,
stride=stride,
padding=dilation,
dilation=dilation,
bias=with_bias)
if self.with_norm:
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, planes, postfix=1)
self.add_module(self.norm1_name, norm1)
self.conv2 = build_conv_layer(
conv_cfg, planes, planes, 3, padding=1, bias=with_bias)
if self.with_norm:
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, planes, postfix=2)
self.add_module(self.norm2_name, norm2)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.with_cp = with_cp
@property
def norm1(self):
"""nn.Module: normalization layer after the first convolution layer"""
return getattr(self, self.norm1_name) if self.with_norm else None
@property
def norm2(self):
"""nn.Module: normalization layer after the second convolution layer"""
return getattr(self, self.norm2_name) if self.with_norm else None
def forward(self, x):
"""Forward function."""
identity = x
out = self.conv1(x)
if self.with_norm:
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
if self.with_norm:
out = self.norm2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
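# unlike the standard BasicBlock, no ReLU follows the residual addition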
return out

@ -167,10 +167,11 @@ def _check_roi_head(config, head):
def _check_roi_extractor(config, roi_extractor, prev_roi_extractor=None):
import torch.nn as nn
# Separate roi_extractor and prev_roi_extractor checks for flexibility
if isinstance(roi_extractor, nn.ModuleList):
if prev_roi_extractor:
prev_roi_extractor = prev_roi_extractor[0]
roi_extractor = roi_extractor[0]
if prev_roi_extractor and isinstance(prev_roi_extractor, nn.ModuleList):
prev_roi_extractor = prev_roi_extractor[0]
assert (len(config.featmap_strides) == len(roi_extractor.roi_layers))
assert (config.out_channels == roi_extractor.out_channels)

@ -87,6 +87,14 @@ def test_htc_aug_test():
assert len(aug_result[0][1]) == 80
def test_scnet_aug_test():
aug_result = model_aug_test_template(
'configs/scnet/scnet_r50_fpn_1x_coco.py')
assert len(aug_result[0]) == 2
assert len(aug_result[0][0]) == 80
assert len(aug_result[0][1]) == 80
def test_cornernet_aug_test():
# get config
cfg = mmcv.Config.fromfile(

@ -12,12 +12,13 @@ from mmdet.models.backbones.resnest import Bottleneck as BottleneckS
from mmdet.models.backbones.resnet import BasicBlock, Bottleneck
from mmdet.models.backbones.resnext import Bottleneck as BottleneckX
from mmdet.models.backbones.trident_resnet import TridentBottleneck
from mmdet.models.utils import ResLayer
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
def is_block(modules):
"""Check if is ResNet building block."""
if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX, Bottle2neck)):
if isinstance(modules, (BasicBlock, Bottleneck, BottleneckX, Bottle2neck,
SimplifiedBasicBlock)):
return True
return False
@ -381,6 +382,59 @@ def test_trident_resnet_bottleneck():
assert x_out.shape == torch.Size([block.num_branch, 64, 56, 56])
def test_simplified_basic_block():
with pytest.raises(AssertionError):
# Not implemented yet.
dcn = dict(type='DCN', deform_groups=1, fallback_on_stride=False)
SimplifiedBasicBlock(64, 64, dcn=dcn)
with pytest.raises(AssertionError):
# Not implemented yet.
plugins = [
dict(
cfg=dict(type='ContextBlock', ratio=1. / 16),
position='after_conv3')
]
SimplifiedBasicBlock(64, 64, plugins=plugins)
with pytest.raises(AssertionError):
# Not implemented yet
plugins = [
dict(
cfg=dict(
type='GeneralizedAttention',
spatial_range=-1,
num_heads=8,
attention_type='0010',
kv_stride=2),
position='after_conv2')
]
SimplifiedBasicBlock(64, 64, plugins=plugins)
with pytest.raises(AssertionError):
# Not implemented yet
SimplifiedBasicBlock(64, 64, with_cp=True)
# test SimplifiedBasicBlock structure and forward
block = SimplifiedBasicBlock(64, 64)
assert block.conv1.in_channels == 64
assert block.conv1.out_channels == 64
assert block.conv1.kernel_size == (3, 3)
assert block.conv2.in_channels == 64
assert block.conv2.out_channels == 64
assert block.conv2.kernel_size == (3, 3)
x = torch.randn(1, 64, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
# test SimplifiedBasicBlock without norm
block = SimplifiedBasicBlock(64, 64, norm_cfg=None)
assert block.norm1 is None
assert block.norm2 is None
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 56, 56])
def test_trident_resnet_backbone():
tridentresnet_config = dict(
num_branch=3,

@ -227,14 +227,24 @@ def test_faster_rcnn_ohem_forward():
assert float(loss.item()) > 0
# HTC is not ready yet
@pytest.mark.parametrize('cfg_file', [
'cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py',
'mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py',
'grid_rcnn/grid_rcnn_r50_fpn_gn-head_2x_coco.py',
'ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py'
'ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py',
'htc/htc_r50_fpn_1x_coco.py',
'scnet/scnet_r50_fpn_20e_coco.py',
])
def test_two_stage_forward(cfg_file):
models_with_semantic = [
'htc/htc_r50_fpn_1x_coco.py',
'scnet/scnet_r50_fpn_20e_coco.py',
]
if cfg_file in models_with_semantic:
with_semantic = True
else:
with_semantic = False
model = _get_detector_cfg(cfg_file)
model['pretrained'] = None
@ -244,19 +254,11 @@ def test_two_stage_forward(cfg_file):
input_shape = (1, 3, 256, 256)
# Test forward train with a non-empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
mm_inputs = _demo_mm_inputs(
input_shape, num_items=[10], with_semantic=with_semantic)
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_bboxes = mm_inputs['gt_bboxes']
gt_labels = mm_inputs['gt_labels']
gt_masks = mm_inputs['gt_masks']
losses = detector.forward(
imgs,
img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
gt_masks=gt_masks,
return_loss=True)
losses = detector.forward(imgs, img_metas, return_loss=True, **mm_inputs)
assert isinstance(losses, dict)
loss, _ = detector._parse_losses(losses)
loss.requires_grad_(True)
@ -264,19 +266,11 @@ def test_two_stage_forward(cfg_file):
loss.backward()
# Test forward train with an empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
mm_inputs = _demo_mm_inputs(
input_shape, num_items=[0], with_semantic=with_semantic)
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_bboxes = mm_inputs['gt_bboxes']
gt_labels = mm_inputs['gt_labels']
gt_masks = mm_inputs['gt_masks']
losses = detector.forward(
imgs,
img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
gt_masks=gt_masks,
return_loss=True)
losses = detector.forward(imgs, img_metas, return_loss=True, **mm_inputs)
assert isinstance(losses, dict)
loss, _ = detector._parse_losses(losses)
loss.requires_grad_(True)
@ -330,7 +324,8 @@ def test_single_stage_forward_cpu(cfg_file):
def _demo_mm_inputs(input_shape=(1, 3, 300, 300),
num_items=None, num_classes=10): # yapf: disable
num_items=None, num_classes=10,
with_semantic=False): # yapf: disable
"""Create a superset of inputs needed to run test or train batches.
Args:
@ -358,6 +353,7 @@ def _demo_mm_inputs(input_shape=(1, 3, 300, 300),
'filename': '<demo>.png',
'scale_factor': 1.0,
'flip': False,
'flip_direction': None,
} for _ in range(N)]
gt_bboxes = []
@ -394,6 +390,14 @@ def _demo_mm_inputs(input_shape=(1, 3, 300, 300),
'gt_bboxes_ignore': None,
'gt_masks': gt_masks,
}
if with_semantic:
# assume gt_semantic_seg using scale 1/8 of the img
gt_semantic_seg = np.random.randint(
0, num_classes, (1, 1, H // 8, W // 8), dtype=np.uint8)
mm_inputs.update(
{'gt_semantic_seg': torch.ByteTensor(gt_semantic_seg)})
return mm_inputs
