Code for AAAI 2021 paper. "SCNet: Training Inference Sample Consistency for Instance Segmentation" (#4356)
* add scnet
* add more configs and minors
* fix unittest for scnet
* update docstring
* add forward test for htc and scnet
* fix build on pytorch 1.3 for empty tensor
* update doc string and minor refactor
* update unit-test and use inheritance for SCNet heads
* support scnet tta
* minor
* update readme
* minor fixes

pull/4546/head
parent 31809ece87
commit 40f168937d
24 changed files with 1294 additions and 38 deletions
@ -0,0 +1,51 @@
# SCNet

## Introduction

[ALGORITHM]

We provide the code for reproducing experiment results of [SCNet](https://arxiv.org/abs/2012.10150).

```
@inproceedings{vu2019cascade,
  title={SCNet: Training Inference Sample Consistency for Instance Segmentation},
  author={Vu, Thang and Haeyong, Kang and Yoo, Chang D},
  booktitle={AAAI},
  year={2021}
}
```

## Dataset

SCNet requires the COCO and COCO-stuff datasets for training. Download and extract both into the COCO dataset path.
The directory should look like this.

```none
mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── coco
│   │   ├── annotations
│   │   ├── train2017
│   │   ├── val2017
│   │   ├── test2017
│   │   ├── stuffthingmaps
```
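Before training, it can help to confirm that the layout above is in place. The following is a minimal sketch using only the Python standard library; `missing_coco_dirs` and `EXPECTED_SUBDIRS` are hypothetical names introduced here for illustration and are not part of mmdetection.

```python
from pathlib import Path

# Sub-directories of data/coco listed in the layout above.
EXPECTED_SUBDIRS = ("annotations", "train2017", "val2017", "test2017",
                    "stuffthingmaps")


def missing_coco_dirs(coco_root):
    """Return the expected sub-directories that are absent under coco_root."""
    root = Path(coco_root)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]
```

Calling `missing_coco_dirs("data/coco")` from the repository root returns an empty list when every directory is present.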

## Results and Models

The results on COCO 2017 val are shown in the table below. Results on test-dev are usually slightly higher than on val.

| Backbone | Style | Lr schd | Mem (GB) | Inf speed (fps) | box AP | mask AP | TTA box AP | TTA mask AP | Config | Download |
|:---------------:|:-------:|:-------:|:--------:|:---------------:|:------:|:-------:|:----------:|:-----------:|:------:|:------------:|
| R-50-FPN | pytorch | 1x | 7.0 | 6.2 | 43.5 | 39.2 | 44.8 | 40.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_1x_coco.py) | [model](https://drive.google.com/file/d/179pcG-sNVDglJoZcQLsM8GWZnJhZtgWx/view?usp=sharing) \| [log](https://drive.google.com/file/d/1ZFS6QhFfxlOnDYPiGpSDP_Fzgb7iDGN3/view?usp=sharing) |
| R-50-FPN | pytorch | 20e | 7.0 | 6.2 | 44.5 | 40.0 | 45.8 | 41.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1hCH4raiUXsnYQYmKc4ADeL5pNc2X72fJ/view?usp=sharing) \| [log](https://drive.google.com/file/d/1-LnkOXN8n5ojQW34H0qZ625cgrnWpqSX/view?usp=sharing) |
| R-101-FPN | pytorch | 20e | 8.9 | 5.8 | 45.8 | 40.9 | 47.3 | 42.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r101_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1LcnegOX9YYZ7G8hqUa1F_Mp6PeT6_3jA/view?usp=sharing) \| [log](https://drive.google.com/file/d/1iRx-9GRgTaIDsz-we3DGwFVH22nbvCLa/view?usp=sharing) |
| X-101-64x4d-FPN | pytorch | 20e | 13.2 | 4.9 | 47.5 | 42.3 | 48.9 | 44.0 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_x101_64x4d_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1MDIoOBKwfXUdtEduz2BQm3MSgC05IlqZ/view?usp=sharing) \| [log](https://drive.google.com/file/d/1OsfQJ8gwtqIQ61k358yxY21sCvbUcRjs/view?usp=sharing) |

### Notes

- Training hyper-parameters are identical to those of [HTC](https://github.com/open-mmlab/mmdetection/tree/master/configs/htc).
- TTA means Test Time Augmentation, which applies horizontal flip and multi-scale testing. Refer to [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_1x_coco.py).
@ -0,0 +1,2 @@
_base_ = './scnet_r50_fpn_20e_coco.py'
model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101))
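The two-line config above relies on `_base_` inheritance: only the keys it names are overridden, and everything else comes from the parent config. The sketch below is a simplified model of that merge (including the `_delete_=True` marker used by the SCNet R-50 config), written in plain Python. `merge_cfg` is a hypothetical helper, not mmcv's actual implementation, and the base values shown are illustrative stand-ins.

```python
import copy


def merge_cfg(base, override):
    """Recursively merge `override` into a copy of `base`.

    A dict carrying `_delete_=True` replaces the base entry outright
    instead of being merged into it.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict):
            value = dict(value)
            delete = value.pop('_delete_', False)
            if not delete and isinstance(merged.get(key), dict):
                merged[key] = merge_cfg(merged[key], value)
                continue
        merged[key] = copy.deepcopy(value)
    return merged


# Illustrative stand-in for the inherited R-50 settings (assumed values).
base = dict(
    model=dict(
        pretrained='torchvision://resnet50',
        backbone=dict(type='ResNet', depth=50)))
# The override from the R-101 config above.
child = dict(
    model=dict(pretrained='torchvision://resnet101', backbone=dict(depth=101)))

merged = merge_cfg(base, child)
```

After the merge, `merged['model']['backbone']` keeps the inherited `type` and only `depth` changes to 101.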
@ -0,0 +1,136 @@
_base_ = '../htc/htc_r50_fpn_1x_coco.py'
# model settings
model = dict(
    type='SCNet',
    roi_head=dict(
        _delete_=True,
        type='SCNetRoIHead',
        num_stages=3,
        stage_loss_weights=[1, 0.5, 0.25],
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        bbox_head=[
            dict(
                type='SCNetBBoxHead',
                num_shared_fcs=2,
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.1, 0.1, 0.2, 0.2]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='SCNetBBoxHead',
                num_shared_fcs=2,
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.05, 0.05, 0.1, 0.1]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
                               loss_weight=1.0)),
            dict(
                type='SCNetBBoxHead',
                num_shared_fcs=2,
                in_channels=256,
                fc_out_channels=1024,
                roi_feat_size=7,
                num_classes=80,
                bbox_coder=dict(
                    type='DeltaXYWHBBoxCoder',
                    target_means=[0., 0., 0., 0.],
                    target_stds=[0.033, 0.033, 0.067, 0.067]),
                reg_class_agnostic=True,
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=False,
                    loss_weight=1.0),
                loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0))
        ],
        mask_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[4, 8, 16, 32]),
        mask_head=dict(
            type='SCNetMaskHead',
            num_convs=12,
            in_channels=256,
            conv_out_channels=256,
            num_classes=80,
            conv_to_res=True,
            loss_mask=dict(
                type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)),
        semantic_roi_extractor=dict(
            type='SingleRoIExtractor',
            roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
            out_channels=256,
            featmap_strides=[8]),
        semantic_head=dict(
            type='SCNetSemanticHead',
            num_ins=5,
            fusion_level=1,
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=183,
            ignore_label=255,
            loss_weight=0.2,
            conv_to_res=True),
        glbctx_head=dict(
            type='GlobalContextHead',
            num_convs=4,
            in_channels=256,
            conv_out_channels=256,
            num_classes=81,
            loss_weight=3.0,
            conv_to_res=True),
        feat_relay_head=dict(
            type='FeatureRelayHead',
            in_channels=1024,
            out_conv_channels=256,
            roi_feat_size=7,
            scale_factor=2)))

# uncomment below code to enable test time augmentations
# img_norm_cfg = dict(
#     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
# test_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(
#         type='MultiScaleFlipAug',
#         img_scale=[(600, 900), (800, 1200), (1000, 1500), (1200, 1800),
#                    (1400, 2100)],
#         flip=True,
#         transforms=[
#             dict(type='Resize', keep_ratio=True),
#             dict(type='RandomFlip', flip_ratio=0.5),
#             dict(type='Normalize', **img_norm_cfg),
#             dict(type='Pad', size_divisor=32),
#             dict(type='ImageToTensor', keys=['img']),
#             dict(type='Collect', keys=['img']),
#         ])
# ]
# data = dict(
#     val=dict(pipeline=test_pipeline),
#     test=dict(pipeline=test_pipeline))
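To see how much work the commented-out TTA pipeline implies, the sketch below enumerates the augmented views it would produce: every scale in `img_scale`, each with and without a flip. It is a plain-Python illustration only. `tta_views` is a hypothetical helper, and it assumes `flip=True` yields both a non-flipped and a flipped view per scale.

```python
from itertools import product

# Scales copied from the commented-out `img_scale` list above.
IMG_SCALES = [(600, 900), (800, 1200), (1000, 1500), (1200, 1800),
              (1400, 2100)]


def tta_views(img_scales, flip=True):
    """List every (scale, flip) combination evaluated per image."""
    flips = (False, True) if flip else (False, )
    return [dict(scale=s, flip=f) for s, f in product(img_scales, flips)]


views = tta_views(IMG_SCALES)
```

With five scales and flipping enabled, each test image is evaluated `len(views)` times before the predictions are merged, which is why the TTA numbers come at a much higher inference cost.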
@ -0,0 +1,4 @@
_base_ = './scnet_r50_fpn_1x_coco.py'
# learning policy
lr_config = dict(step=[16, 19])
total_epochs = 20
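The `20e` schedule above only changes the decay milestones and the epoch count. As a rough illustration of what `step=[16, 19]` means, the sketch below computes a step-decayed learning rate per epoch. `step_lr` is a hypothetical helper, and the base learning rate of 0.02, the decay factor of 0.1, and the absence of warmup are assumptions rather than values stated in this config.

```python
def step_lr(base_lr, epoch, steps, gamma=0.1):
    """Learning rate at a 0-based epoch: one decay per milestone reached."""
    num_decays = sum(1 for s in steps if epoch >= s)
    return base_lr * gamma**num_decays


# Assumed base LR of 0.02; milestones and epoch count from the config above.
schedule = [step_lr(0.02, epoch, steps=[16, 19]) for epoch in range(20)]
```

Under these assumptions the rate stays at its base value for epochs 0 to 15, drops by 10x for epochs 16 to 18, and drops by another 10x for the final epoch.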
@ -0,0 +1,14 @@
_base_ = './scnet_r50_fpn_20e_coco.py'
model = dict(
    pretrained='open-mmlab://resnext101_64x4d',
    backbone=dict(
        type='ResNeXt',
        depth=101,
        groups=64,
        base_width=4,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=True,
        style='pytorch'))
@ -0,0 +1,3 @@
_base_ = './scnet_x101_64x4d_fpn_20e_coco.py'
data = dict(samples_per_gpu=1, workers_per_gpu=1)
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
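This config halves `samples_per_gpu` and sets `lr=0.01`. One plausible reading is the linear scaling rule, where the learning rate is scaled in proportion to the total batch size. The sketch below shows that arithmetic. `scale_lr` is a hypothetical helper, and the reference setting of 8 GPUs with 2 samples each at `lr=0.02` is an assumption, not something this config states.

```python
def scale_lr(ref_lr, ref_batch_size, batch_size):
    """Scale a reference learning rate linearly with total batch size."""
    return ref_lr * batch_size / ref_batch_size


# Assumed reference: 8 GPUs x 2 samples per GPU at lr=0.02.
# This config: 8 GPUs x 1 sample per GPU (the GPU count is an assumption).
lr = scale_lr(ref_lr=0.02, ref_batch_size=8 * 2, batch_size=8 * 1)
```

Under those assumptions the result matches the `lr=0.01` set in the optimizer above.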
@ -0,0 +1,10 @@
from ..builder import DETECTORS
from .cascade_rcnn import CascadeRCNN


@DETECTORS.register_module()
class SCNet(CascadeRCNN):
    """Implementation of `SCNet <https://arxiv.org/abs/2012.10150>`_"""

    def __init__(self, **kwargs):
        super(SCNet, self).__init__(**kwargs)
@ -0,0 +1,76 @@
from mmdet.models.builder import HEADS
from .convfc_bbox_head import ConvFCBBoxHead


@HEADS.register_module()
class SCNetBBoxHead(ConvFCBBoxHead):
    """BBox head for `SCNet <https://arxiv.org/abs/2012.10150>`_.

    This inherits ``ConvFCBBoxHead`` with a modified forward() function that
    can also return the intermediate shared feature.
    """

    def _forward_shared(self, x):
        """Forward function for shared part."""
        if self.num_shared_convs > 0:
            for conv in self.shared_convs:
                x = conv(x)

        if self.num_shared_fcs > 0:
            if self.with_avg_pool:
                x = self.avg_pool(x)

            x = x.flatten(1)

            for fc in self.shared_fcs:
                x = self.relu(fc(x))

        return x

    def _forward_cls_reg(self, x):
        """Forward function for classification and regression parts."""
        x_cls = x
        x_reg = x

        for conv in self.cls_convs:
            x_cls = conv(x_cls)
        if x_cls.dim() > 2:
            if self.with_avg_pool:
                x_cls = self.avg_pool(x_cls)
            x_cls = x_cls.flatten(1)
        for fc in self.cls_fcs:
            x_cls = self.relu(fc(x_cls))

        for conv in self.reg_convs:
            x_reg = conv(x_reg)
        if x_reg.dim() > 2:
            if self.with_avg_pool:
                x_reg = self.avg_pool(x_reg)
            x_reg = x_reg.flatten(1)
        for fc in self.reg_fcs:
            x_reg = self.relu(fc(x_reg))

        cls_score = self.fc_cls(x_cls) if self.with_cls else None
        bbox_pred = self.fc_reg(x_reg) if self.with_reg else None

        return cls_score, bbox_pred

    def forward(self, x, return_shared_feat=False):
        """Forward function.

        Args:
            x (Tensor): input features
            return_shared_feat (bool): If True, return cls-reg-shared feature.

        Return:
            out (tuple[Tensor]): contains ``cls_score`` and ``bbox_pred``;
                if ``return_shared_feat`` is True, ``x_shared`` is appended
                to the returned tuple.
        """
        x_shared = self._forward_shared(x)
        out = self._forward_cls_reg(x_shared)

        if return_shared_feat:
            out += (x_shared, )

        return out
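The `forward()` above returns either two or three values depending on `return_shared_feat`. The toy sketch below mirrors only that calling convention, using plain Python lists instead of tensors. `toy_forward` and its arithmetic are stand-ins invented for illustration and say nothing about the real head's computation.

```python
def toy_forward(x, return_shared_feat=False):
    """Mimic the return convention of SCNetBBoxHead.forward()."""
    x_shared = [v * 2 for v in x]  # stand-in for the shared conv/fc trunk
    out = (sum(x_shared), max(x_shared))  # stand-ins for cls_score, bbox_pred
    if return_shared_feat:
        out += (x_shared, )  # append the shared feature as a third item
    return out


cls_score, bbox_pred = toy_forward([1, 2])
cls_score, bbox_pred, shared = toy_forward([1, 2], return_shared_feat=True)
```

Callers that need the shared feature unpack three values. Callers that omit the flag keep the usual two-value interface.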
@ -1,12 +1,17 @@
 from .coarse_mask_head import CoarseMaskHead
 from .fcn_mask_head import FCNMaskHead
+from .feature_relay_head import FeatureRelayHead
 from .fused_semantic_head import FusedSemanticHead
+from .global_context_head import GlobalContextHead
 from .grid_head import GridHead
 from .htc_mask_head import HTCMaskHead
 from .mask_point_head import MaskPointHead
 from .maskiou_head import MaskIoUHead
+from .scnet_mask_head import SCNetMaskHead
+from .scnet_semantic_head import SCNetSemanticHead

 __all__ = [
     'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
-    'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead'
+    'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead', 'SCNetMaskHead',
+    'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead'
 ]
@ -0,0 +1,55 @@
import torch.nn as nn
from mmcv.cnn import kaiming_init
from mmcv.runner import auto_fp16

from mmdet.models.builder import HEADS


@HEADS.register_module()
class FeatureRelayHead(nn.Module):
    """Feature Relay Head used in SCNet https://arxiv.org/abs/2012.10150.

    Args:
        in_channels (int, optional): number of input channels. Default: 1024.
        out_conv_channels (int, optional): number of channels of the relayed
            feature map. Default: 256.
        roi_feat_size (int, optional): roi feat size at box head. Default: 7.
        scale_factor (int, optional): scale factor to match roi feat size
            at mask head. Default: 2.
    """

    def __init__(self,
                 in_channels=1024,
                 out_conv_channels=256,
                 roi_feat_size=7,
                 scale_factor=2):
        super(FeatureRelayHead, self).__init__()
        assert isinstance(roi_feat_size, int)

        self.in_channels = in_channels
        self.out_conv_channels = out_conv_channels
        self.roi_feat_size = roi_feat_size
        self.out_channels = (roi_feat_size**2) * out_conv_channels
        self.scale_factor = scale_factor
        self.fp16_enabled = False

        self.fc = nn.Linear(self.in_channels, self.out_channels)
        self.upsample = nn.Upsample(
            scale_factor=scale_factor, mode='bilinear', align_corners=True)

    def init_weights(self):
        """Init weights for the head."""
        kaiming_init(self.fc)

    @auto_fp16()
    def forward(self, x):
        """Forward function."""
        N, in_C = x.shape
        if N > 0:
            out_C = self.out_conv_channels
            out_HW = self.roi_feat_size
            x = self.fc(x)
            x = x.reshape(N, out_C, out_HW, out_HW)
            x = self.upsample(x)
            return x
        return None
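The relay head turns a flat box-head feature into a spatial map that can be added to the mask features. The sketch below traces the tensor shapes through `fc`, `reshape`, and `upsample` using plain integer arithmetic, with the default arguments from the class above. `relay_shapes` is a hypothetical helper introduced only to make the shape bookkeeping explicit.

```python
def relay_shapes(num_rois,
                 in_channels=1024,
                 out_conv_channels=256,
                 roi_feat_size=7,
                 scale_factor=2):
    """Trace tensor shapes through the relay head for `num_rois` RoIs."""
    fc_out = roi_feat_size**2 * out_conv_channels
    up_size = roi_feat_size * scale_factor
    return dict(
        input=(num_rois, in_channels),
        fc_out=(num_rois, fc_out),
        reshaped=(num_rois, out_conv_channels, roi_feat_size, roi_feat_size),
        upsampled=(num_rois, out_conv_channels, up_size, up_size))


shapes = relay_shapes(num_rois=8)
```

With the defaults, the upsampled map is 14x14 with 256 channels, which lines up with the `output_size=14` and `out_channels=256` of the mask RoI extractor in the SCNet config.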
@ -0,0 +1,101 @@
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import auto_fp16, force_fp32

from mmdet.models.builder import HEADS
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock


@HEADS.register_module()
class GlobalContextHead(nn.Module):
    """Global context head used in SCNet https://arxiv.org/abs/2012.10150.

    Args:
        num_convs (int, optional): number of convolutional layer in GlbCtxHead.
            Default: 4.
        in_channels (int, optional): number of input channels. Default: 256.
        conv_out_channels (int, optional): number of output channels before
            classification layer. Default: 256.
        loss_weight (float, optional): global context loss weight. Default: 1.
        conv_cfg (dict, optional): config to init conv layer. Default: None.
        norm_cfg (dict, optional): config to init norm layer. Default: None.
        conv_to_res (bool, optional): if True, 2 convs will be grouped into
            1 `SimplifiedBasicBlock` using a skip connection. Default: False.
    """

    def __init__(self,
                 num_convs=4,
                 in_channels=256,
                 conv_out_channels=256,
                 num_classes=81,
                 loss_weight=1.0,
                 conv_cfg=None,
                 norm_cfg=None,
                 conv_to_res=False):
        super(GlobalContextHead, self).__init__()
        self.num_convs = num_convs
        self.in_channels = in_channels
        self.conv_out_channels = conv_out_channels
        self.num_classes = num_classes
        self.loss_weight = loss_weight
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.conv_to_res = conv_to_res
        self.fp16_enabled = False

        if self.conv_to_res:
            num_res_blocks = num_convs // 2
            self.convs = ResLayer(
                SimplifiedBasicBlock,
                in_channels,
                self.conv_out_channels,
                num_res_blocks,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
            self.num_convs = num_res_blocks
        else:
            self.convs = nn.ModuleList()
            for i in range(self.num_convs):
                in_channels = self.in_channels if i == 0 else conv_out_channels
                self.convs.append(
                    ConvModule(
                        in_channels,
                        conv_out_channels,
                        3,
                        padding=1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg))

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(conv_out_channels, num_classes)

        self.criterion = nn.BCEWithLogitsLoss()

    def init_weights(self):
        """Init weights for the head."""
        nn.init.normal_(self.fc.weight, 0, 0.01)
        nn.init.constant_(self.fc.bias, 0)

    @auto_fp16()
    def forward(self, feats):
        """Forward function."""
        x = feats[-1]
        for i in range(self.num_convs):
            x = self.convs[i](x)
        x = self.pool(x)

        # multi-class prediction
        mc_pred = x.reshape(x.size(0), -1)
        mc_pred = self.fc(mc_pred)

        return mc_pred, x

    @force_fp32(apply_to=('pred', ))
    def loss(self, pred, labels):
        """Loss function."""
        labels = [lbl.unique() for lbl in labels]
        targets = pred.new_zeros(pred.size())
        for i, label in enumerate(labels):
            targets[i, label] = 1.0
        loss = self.loss_weight * self.criterion(pred, targets)
        return loss
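The `loss()` method above builds an image-level multi-label target: for each image, every class that appears among its ground-truth labels is set to 1, regardless of how many instances there are. The sketch below reproduces that target construction with plain Python lists in place of tensors. `multi_hot_targets` is a hypothetical name used only for this illustration.

```python
def multi_hot_targets(gt_labels_per_image, num_classes):
    """Build one multi-hot row per image from its instance labels."""
    targets = []
    for labels in gt_labels_per_image:
        row = [0.0] * num_classes
        for label in set(labels):  # duplicates collapse, as with unique()
            row[label] = 1.0
        targets.append(row)
    return targets


# Two images: the first has two instances of class 2 and one of class 0.
targets = multi_hot_targets([[2, 2, 0], [1]], num_classes=4)
```

These rows play the role of `targets` passed to the binary cross-entropy criterion, so the head learns which classes are present in the image as a whole.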
@ -0,0 +1,27 @@
from mmdet.models.builder import HEADS
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
from .fcn_mask_head import FCNMaskHead


@HEADS.register_module()
class SCNetMaskHead(FCNMaskHead):
    """Mask head for `SCNet <https://arxiv.org/abs/2012.10150>`_.

    Args:
        conv_to_res (bool, optional): if True, change the conv layers to
            ``SimplifiedBasicBlock``.
    """

    def __init__(self, conv_to_res=True, **kwargs):
        super(SCNetMaskHead, self).__init__(**kwargs)
        self.conv_to_res = conv_to_res
        if conv_to_res:
            assert self.conv_kernel_size == 3
            self.num_res_blocks = self.num_convs // 2
            self.convs = ResLayer(
                SimplifiedBasicBlock,
                self.in_channels,
                self.conv_out_channels,
                self.num_res_blocks,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
@ -0,0 +1,27 @@
from mmdet.models.builder import HEADS
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
from .fused_semantic_head import FusedSemanticHead


@HEADS.register_module()
class SCNetSemanticHead(FusedSemanticHead):
    """Semantic head for `SCNet <https://arxiv.org/abs/2012.10150>`_.

    Args:
        conv_to_res (bool, optional): if True, change the conv layers to
            ``SimplifiedBasicBlock``.
    """

    def __init__(self, conv_to_res=True, **kwargs):
        super(SCNetSemanticHead, self).__init__(**kwargs)
        self.conv_to_res = conv_to_res
        if self.conv_to_res:
            num_res_blocks = self.num_convs // 2
            self.convs = ResLayer(
                SimplifiedBasicBlock,
                self.in_channels,
                self.conv_out_channels,
                num_res_blocks,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg)
            self.num_convs = num_res_blocks
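With `conv_to_res=True`, the mask, semantic, and global context heads all replace their plain conv stack with `num_convs // 2` residual blocks. The sketch below applies that integer division to the `num_convs` values from the SCNet R-50 config. `num_res_blocks` and the `HEAD_NUM_CONVS` table are hypothetical names used only for illustration.

```python
def num_res_blocks(num_convs):
    """Number of residual blocks that replace `num_convs` plain convs."""
    return num_convs // 2


# `num_convs` values taken from the SCNet R-50 config.
HEAD_NUM_CONVS = dict(mask_head=12, semantic_head=4, glbctx_head=4)
blocks = {name: num_res_blocks(n) for name, n in HEAD_NUM_CONVS.items()}
```

Because the division floors, an odd `num_convs` would silently drop one conv, so even values are the natural choice when `conv_to_res` is enabled.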
@ -0,0 +1,582 @@
import torch
import torch.nn.functional as F

from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
                        merge_aug_masks, multiclass_nms)
from ..builder import HEADS, build_head, build_roi_extractor
from .cascade_roi_head import CascadeRoIHead


@HEADS.register_module()
class SCNetRoIHead(CascadeRoIHead):
    """RoIHead for `SCNet <https://arxiv.org/abs/2012.10150>`_.

    Args:
        num_stages (int): number of cascade stages.
        stage_loss_weights (list): loss weight of cascade stages.
        semantic_roi_extractor (dict): config to init semantic roi extractor.
        semantic_head (dict): config to init semantic head.
        feat_relay_head (dict): config to init feature_relay_head.
        glbctx_head (dict): config to init global context head.
    """

    def __init__(self,
                 num_stages,
                 stage_loss_weights,
                 semantic_roi_extractor=None,
                 semantic_head=None,
                 feat_relay_head=None,
                 glbctx_head=None,
                 **kwargs):
        super(SCNetRoIHead, self).__init__(num_stages, stage_loss_weights,
                                           **kwargs)
        assert self.with_bbox and self.with_mask
        assert not self.with_shared_head  # shared head is not supported

        if semantic_head is not None:
            self.semantic_roi_extractor = build_roi_extractor(
                semantic_roi_extractor)
            self.semantic_head = build_head(semantic_head)

        if feat_relay_head is not None:
            self.feat_relay_head = build_head(feat_relay_head)

        if glbctx_head is not None:
            self.glbctx_head = build_head(glbctx_head)

    def init_mask_head(self, mask_roi_extractor, mask_head):
        """Initialize ``mask_head``"""
        if mask_roi_extractor is not None:
            self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
            self.mask_head = build_head(mask_head)

    def init_weights(self, pretrained):
        """Initialize the weights in head.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        for i in range(self.num_stages):
            if self.with_bbox:
                self.bbox_roi_extractor[i].init_weights()
                self.bbox_head[i].init_weights()
        if self.with_mask:
            self.mask_roi_extractor.init_weights()
            self.mask_head.init_weights()
        if self.with_semantic:
            self.semantic_head.init_weights()
        if self.with_glbctx:
            self.glbctx_head.init_weights()
        if self.with_feat_relay:
            self.feat_relay_head.init_weights()

    @property
    def with_semantic(self):
        """bool: whether the head has semantic head"""
        return hasattr(self,
                       'semantic_head') and self.semantic_head is not None

    @property
    def with_feat_relay(self):
        """bool: whether the head has feature relay head"""
        return (hasattr(self, 'feat_relay_head')
                and self.feat_relay_head is not None)

    @property
    def with_glbctx(self):
        """bool: whether the head has global context head"""
        return hasattr(self, 'glbctx_head') and self.glbctx_head is not None

    def _fuse_glbctx(self, roi_feats, glbctx_feat, rois):
        """Fuse global context feats with roi feats."""
        assert roi_feats.size(0) == rois.size(0)
        img_inds = torch.unique(rois[:, 0].cpu(), sorted=True).long()
        fused_feats = torch.zeros_like(roi_feats)
        for img_id in img_inds:
            inds = (rois[:, 0] == img_id.item())
            fused_feats[inds] = roi_feats[inds] + glbctx_feat[img_id]
        return fused_feats

    def _slice_pos_feats(self, feats, sampling_results):
        """Get features from pos rois."""
        num_rois = [res.bboxes.size(0) for res in sampling_results]
        num_pos_rois = [res.pos_bboxes.size(0) for res in sampling_results]
        inds = torch.zeros(sum(num_rois), dtype=torch.bool)
        start = 0
        for i in range(len(num_rois)):
            start = 0 if i == 0 else start + num_rois[i - 1]
            stop = start + num_pos_rois[i]
            inds[start:stop] = 1
        sliced_feats = feats[inds]
        return sliced_feats

    def _bbox_forward(self,
                      stage,
                      x,
                      rois,
                      semantic_feat=None,
                      glbctx_feat=None):
        """Box head forward function used in both training and testing."""
        bbox_roi_extractor = self.bbox_roi_extractor[stage]
        bbox_head = self.bbox_head[stage]
        bbox_feats = bbox_roi_extractor(
            x[:len(bbox_roi_extractor.featmap_strides)], rois)
        if self.with_semantic and semantic_feat is not None:
            bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
                                                             rois)
            if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
                bbox_semantic_feat = F.adaptive_avg_pool2d(
                    bbox_semantic_feat, bbox_feats.shape[-2:])
            bbox_feats += bbox_semantic_feat
        if self.with_glbctx and glbctx_feat is not None:
            bbox_feats = self._fuse_glbctx(bbox_feats, glbctx_feat, rois)
        cls_score, bbox_pred, relayed_feat = bbox_head(
            bbox_feats, return_shared_feat=True)

        bbox_results = dict(
            cls_score=cls_score,
            bbox_pred=bbox_pred,
            relayed_feat=relayed_feat)
        return bbox_results

    def _mask_forward(self,
                      x,
                      rois,
                      semantic_feat=None,
                      glbctx_feat=None,
                      relayed_feat=None):
        """Mask head forward function used in both training and testing."""
        mask_feats = self.mask_roi_extractor(
            x[:self.mask_roi_extractor.num_inputs], rois)
        if self.with_semantic and semantic_feat is not None:
            mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
                                                             rois)
            if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
                mask_semantic_feat = F.adaptive_avg_pool2d(
                    mask_semantic_feat, mask_feats.shape[-2:])
            mask_feats += mask_semantic_feat
        if self.with_glbctx and glbctx_feat is not None:
            mask_feats = self._fuse_glbctx(mask_feats, glbctx_feat, rois)
        if self.with_feat_relay and relayed_feat is not None:
            mask_feats = mask_feats + relayed_feat
        mask_pred = self.mask_head(mask_feats)
        mask_results = dict(mask_pred=mask_pred)

        return mask_results

    def _bbox_forward_train(self,
                            stage,
                            x,
                            sampling_results,
                            gt_bboxes,
                            gt_labels,
                            rcnn_train_cfg,
                            semantic_feat=None,
                            glbctx_feat=None):
        """Run forward function and calculate loss for box head in training."""
        bbox_head = self.bbox_head[stage]
        rois = bbox2roi([res.bboxes for res in sampling_results])
        bbox_results = self._bbox_forward(
            stage,
            x,
            rois,
            semantic_feat=semantic_feat,
            glbctx_feat=glbctx_feat)

        bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes,
                                             gt_labels, rcnn_train_cfg)
        loss_bbox = bbox_head.loss(bbox_results['cls_score'],
                                   bbox_results['bbox_pred'], rois,
                                   *bbox_targets)

        bbox_results.update(
            loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
        return bbox_results

    def _mask_forward_train(self,
                            x,
                            sampling_results,
                            gt_masks,
                            rcnn_train_cfg,
                            semantic_feat=None,
                            glbctx_feat=None,
                            relayed_feat=None):
        """Run forward function and calculate loss for mask head in
        training."""
        pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
        mask_results = self._mask_forward(
            x,
            pos_rois,
            semantic_feat=semantic_feat,
            glbctx_feat=glbctx_feat,
            relayed_feat=relayed_feat)

        mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
                                                  rcnn_train_cfg)
        pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head.loss(mask_results['mask_pred'],
                                        mask_targets, pos_labels)

        mask_results = loss_mask
        return mask_results
||||||
|
|
||||||
|
    def forward_train(self,
                      x,
                      img_metas,
                      proposal_list,
                      gt_bboxes,
                      gt_labels,
                      gt_bboxes_ignore=None,
                      gt_masks=None,
                      gt_semantic_seg=None):
        """
        Args:
            x (list[Tensor]): list of multi-level img features.

            img_metas (list[dict]): list of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmdet/datasets/pipelines/formatting.py:Collect`.

            proposal_list (list[Tensor]): list of region proposals.

            gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
                shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.

            gt_labels (list[Tensor]): class indices corresponding to each box.

            gt_bboxes_ignore (None, list[Tensor]): specify which bounding
                boxes can be ignored when computing the loss.

            gt_masks (None, Tensor): true segmentation masks for each box
                used if the architecture supports a segmentation task.

            gt_semantic_seg (None, list[Tensor]): semantic segmentation masks
                used if the architecture supports a semantic segmentation
                task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        losses = dict()

        # semantic segmentation branch
        if self.with_semantic:
            semantic_pred, semantic_feat = self.semantic_head(x)
            loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
            losses['loss_semantic_seg'] = loss_seg
        else:
            semantic_feat = None

        # global context branch
        if self.with_glbctx:
            mc_pred, glbctx_feat = self.glbctx_head(x)
            loss_glbctx = self.glbctx_head.loss(mc_pred, gt_labels)
            losses['loss_glbctx'] = loss_glbctx
        else:
            glbctx_feat = None

        for i in range(self.num_stages):
            self.current_stage = i
            rcnn_train_cfg = self.train_cfg[i]
            lw = self.stage_loss_weights[i]

            # assign gts and sample proposals
            sampling_results = []
            bbox_assigner = self.bbox_assigner[i]
            bbox_sampler = self.bbox_sampler[i]
            num_imgs = len(img_metas)
            if gt_bboxes_ignore is None:
                gt_bboxes_ignore = [None for _ in range(num_imgs)]

            for j in range(num_imgs):
                assign_result = bbox_assigner.assign(proposal_list[j],
                                                     gt_bboxes[j],
                                                     gt_bboxes_ignore[j],
                                                     gt_labels[j])
                sampling_result = bbox_sampler.sample(
                    assign_result,
                    proposal_list[j],
                    gt_bboxes[j],
                    gt_labels[j],
                    feats=[lvl_feat[j][None] for lvl_feat in x])
                sampling_results.append(sampling_result)

            bbox_results = \
                self._bbox_forward_train(
                    i, x, sampling_results, gt_bboxes, gt_labels,
                    rcnn_train_cfg, semantic_feat, glbctx_feat)
            roi_labels = bbox_results['bbox_targets'][0]

            for name, value in bbox_results['loss_bbox'].items():
                losses[f's{i}.{name}'] = (
                    value * lw if 'loss' in name else value)

            # refine boxes
            if i < self.num_stages - 1:
                pos_is_gts = [res.pos_is_gt for res in sampling_results]
                with torch.no_grad():
                    proposal_list = self.bbox_head[i].refine_bboxes(
                        bbox_results['rois'], roi_labels,
                        bbox_results['bbox_pred'], pos_is_gts, img_metas)

        if self.with_feat_relay:
            relayed_feat = self._slice_pos_feats(bbox_results['relayed_feat'],
                                                 sampling_results)
            relayed_feat = self.feat_relay_head(relayed_feat)
        else:
            relayed_feat = None

        mask_results = self._mask_forward_train(x, sampling_results, gt_masks,
                                                rcnn_train_cfg, semantic_feat,
                                                glbctx_feat, relayed_feat)
        mask_lw = sum(self.stage_loss_weights)
        losses['loss_mask'] = mask_lw * mask_results['loss_mask']

        return losses

    def simple_test(self, x, proposal_list, img_metas, rescale=False):
        """Test without augmentation."""
        if self.with_semantic:
            _, semantic_feat = self.semantic_head(x)
        else:
            semantic_feat = None

        if self.with_glbctx:
            _, glbctx_feat = self.glbctx_head(x)
        else:
            glbctx_feat = None

        num_imgs = len(proposal_list)
        img_shapes = tuple(meta['img_shape'] for meta in img_metas)
        ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
        scale_factors = tuple(meta['scale_factor'] for meta in img_metas)

        # "ms" in variable names means multi-stage
        ms_scores = []
        rcnn_test_cfg = self.test_cfg

        rois = bbox2roi(proposal_list)
        for i in range(self.num_stages):
            bbox_head = self.bbox_head[i]
            bbox_results = self._bbox_forward(
                i,
                x,
                rois,
                semantic_feat=semantic_feat,
                glbctx_feat=glbctx_feat)
            # split batch bbox prediction back to each image
            cls_score = bbox_results['cls_score']
            bbox_pred = bbox_results['bbox_pred']
            num_proposals_per_img = tuple(len(p) for p in proposal_list)
            rois = rois.split(num_proposals_per_img, 0)
            cls_score = cls_score.split(num_proposals_per_img, 0)
            bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
            ms_scores.append(cls_score)

            if i < self.num_stages - 1:
                bbox_label = [s[:, :-1].argmax(dim=1) for s in cls_score]
                # j indexes images; i is the stage index of the outer loop
                rois = torch.cat([
                    bbox_head.regress_by_class(rois[j], bbox_label[j],
                                               bbox_pred[j], img_metas[j])
                    for j in range(num_imgs)
                ])

        # average scores of each image by stages
        cls_score = [
            sum([score[i] for score in ms_scores]) / float(len(ms_scores))
            for i in range(num_imgs)
        ]

        # apply bbox post-processing to each image individually
        det_bboxes = []
        det_labels = []
        for i in range(num_imgs):
            det_bbox, det_label = self.bbox_head[-1].get_bboxes(
                rois[i],
                cls_score[i],
                bbox_pred[i],
                img_shapes[i],
                scale_factors[i],
                rescale=rescale,
                cfg=rcnn_test_cfg)
            det_bboxes.append(det_bbox)
            det_labels.append(det_label)
        det_bbox_results = [
            bbox2result(det_bboxes[i], det_labels[i],
                        self.bbox_head[-1].num_classes)
            for i in range(num_imgs)
        ]

        if self.with_mask:
            if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
                mask_classes = self.mask_head.num_classes
                det_segm_results = [[[] for _ in range(mask_classes)]
                                    for _ in range(num_imgs)]
            else:
                if rescale and not isinstance(scale_factors[0], float):
                    scale_factors = [
                        torch.from_numpy(scale_factor).to(det_bboxes[0].device)
                        for scale_factor in scale_factors
                    ]
                _bboxes = [
                    det_bboxes[i][:, :4] *
                    scale_factors[i] if rescale else det_bboxes[i]
                    for i in range(num_imgs)
                ]
                mask_rois = bbox2roi(_bboxes)

                # get relay feature on mask_rois
                bbox_results = self._bbox_forward(
                    -1,
                    x,
                    mask_rois,
                    semantic_feat=semantic_feat,
                    glbctx_feat=glbctx_feat)
                relayed_feat = bbox_results['relayed_feat']
                relayed_feat = self.feat_relay_head(relayed_feat)

                mask_results = self._mask_forward(
                    x,
                    mask_rois,
                    semantic_feat=semantic_feat,
                    glbctx_feat=glbctx_feat,
                    relayed_feat=relayed_feat)
                mask_pred = mask_results['mask_pred']

                # split batch mask prediction back to each image
                num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes)
                mask_preds = mask_pred.split(num_bbox_per_img, 0)

                # apply mask post-processing to each image individually
                det_segm_results = []
                for i in range(num_imgs):
                    if det_bboxes[i].shape[0] == 0:
                        det_segm_results.append(
                            [[] for _ in range(self.mask_head.num_classes)])
                    else:
                        segm_result = self.mask_head.get_seg_masks(
                            mask_preds[i], _bboxes[i], det_labels[i],
                            self.test_cfg, ori_shapes[i], scale_factors[i],
                            rescale)
                        det_segm_results.append(segm_result)

        # return results
        if self.with_mask:
            return list(zip(det_bbox_results, det_segm_results))
        else:
            return det_bbox_results

    def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
        """Test with augmentations."""
        if self.with_semantic:
            semantic_feats = [
                self.semantic_head(feat)[1] for feat in img_feats
            ]
        else:
            semantic_feats = [None] * len(img_metas)

        if self.with_glbctx:
            glbctx_feats = [self.glbctx_head(feat)[1] for feat in img_feats]
        else:
            glbctx_feats = [None] * len(img_metas)

        rcnn_test_cfg = self.test_cfg
        aug_bboxes = []
        aug_scores = []
        for x, img_meta, semantic_feat, glbctx_feat in zip(
                img_feats, img_metas, semantic_feats, glbctx_feats):
            # only one image in the batch
            img_shape = img_meta[0]['img_shape']
            scale_factor = img_meta[0]['scale_factor']
            flip = img_meta[0]['flip']

            proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
                                     scale_factor, flip)
            # "ms" in variable names means multi-stage
            ms_scores = []

            rois = bbox2roi([proposals])
            for i in range(self.num_stages):
                bbox_head = self.bbox_head[i]
                bbox_results = self._bbox_forward(
                    i,
                    x,
                    rois,
                    semantic_feat=semantic_feat,
                    glbctx_feat=glbctx_feat)
                ms_scores.append(bbox_results['cls_score'])
                if i < self.num_stages - 1:
                    # exclude the background column (last), as in simple_test
                    bbox_label = bbox_results['cls_score'][:, :-1].argmax(
                        dim=1)
                    rois = bbox_head.regress_by_class(
                        rois, bbox_label, bbox_results['bbox_pred'],
                        img_meta[0])

            cls_score = sum(ms_scores) / float(len(ms_scores))
            bboxes, scores = self.bbox_head[-1].get_bboxes(
                rois,
                cls_score,
                bbox_results['bbox_pred'],
                img_shape,
                scale_factor,
                rescale=False,
                cfg=None)
            aug_bboxes.append(bboxes)
            aug_scores.append(scores)

        # after merging, bboxes will be rescaled to the original image size
        merged_bboxes, merged_scores = merge_aug_bboxes(
            aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
        det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
                                                rcnn_test_cfg.score_thr,
                                                rcnn_test_cfg.nms,
                                                rcnn_test_cfg.max_per_img)

        det_bbox_results = bbox2result(det_bboxes, det_labels,
                                       self.bbox_head[-1].num_classes)

        if self.with_mask:
            if det_bboxes.shape[0] == 0:
                det_segm_results = [[]
                                    for _ in range(self.mask_head.num_classes)]
            else:
                aug_masks = []
                for x, img_meta, semantic_feat, glbctx_feat in zip(
                        img_feats, img_metas, semantic_feats, glbctx_feats):
                    img_shape = img_meta[0]['img_shape']
                    scale_factor = img_meta[0]['scale_factor']
                    flip = img_meta[0]['flip']
                    _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
                                           scale_factor, flip)
                    mask_rois = bbox2roi([_bboxes])
                    # get relay feature on mask_rois
                    bbox_results = self._bbox_forward(
                        -1,
                        x,
                        mask_rois,
                        semantic_feat=semantic_feat,
                        glbctx_feat=glbctx_feat)
                    relayed_feat = bbox_results['relayed_feat']
                    relayed_feat = self.feat_relay_head(relayed_feat)
                    mask_results = self._mask_forward(
                        x,
                        mask_rois,
                        semantic_feat=semantic_feat,
                        glbctx_feat=glbctx_feat,
                        relayed_feat=relayed_feat)
                    mask_pred = mask_results['mask_pred']
                    aug_masks.append(mask_pred.sigmoid().cpu().numpy())
                merged_masks = merge_aug_masks(aug_masks, img_metas,
                                               self.test_cfg)
                ori_shape = img_metas[0][0]['ori_shape']
                det_segm_results = self.mask_head.get_seg_masks(
                    merged_masks,
                    det_bboxes,
                    det_labels,
                    rcnn_test_cfg,
                    ori_shape,
                    scale_factor=1.0,
                    rescale=False)
            return [(det_bbox_results, det_segm_results)]
        else:
            return [det_bbox_results]
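
The loss bookkeeping in `forward_train` above can be illustrated in isolation. The following is a minimal sketch in plain Python, not part of the commit: the helper name `collect_stage_losses`, the stage weights `[1.0, 0.5, 0.25]`, and all loss values are hypothetical and chosen only for illustration. It mirrors two things the method does: each stage's outputs are stored under the key `s{i}.{name}`, with only the entries whose name contains `loss` scaled by that stage's weight, and the single mask loss is scaled by the sum of all stage weights.

```python
def collect_stage_losses(stage_outputs, stage_loss_weights):
    """Merge per-stage outputs into one flat dict keyed 's{i}.{name}'.

    Entries whose name contains 'loss' are scaled by the stage weight;
    other entries (e.g. accuracy metrics) are copied unchanged.
    """
    losses = {}
    for i, (outputs, lw) in enumerate(zip(stage_outputs, stage_loss_weights)):
        for name, value in outputs.items():
            losses[f's{i}.{name}'] = value * lw if 'loss' in name else value
    return losses


# Hypothetical per-stage weights and values, for illustration only.
stage_loss_weights = [1.0, 0.5, 0.25]
stage_outputs = [
    {'loss_cls': 0.8, 'loss_bbox': 0.4, 'acc': 90.0},
    {'loss_cls': 0.6, 'loss_bbox': 0.2, 'acc': 92.0},
    {'loss_cls': 0.4, 'loss_bbox': 0.8, 'acc': 94.0},
]
losses = collect_stage_losses(stage_outputs, stage_loss_weights)

# The single mask head is weighted by the sum of all stage weights.
raw_mask_loss = 2.0  # made-up value standing in for mask_results['loss_mask']
mask_lw = sum(stage_loss_weights)
losses['loss_mask'] = mask_lw * raw_mask_loss

for key in sorted(losses):
    print(key, losses[key])
```

Weighting the mask loss by the summed stage weights keeps its scale comparable to the combined box losses, since there is one mask head but several box stages.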