Code for AAAI 2021 paper. "SCNet: Training Inference Sample Consistency for Instance Segmentation" (#4356)
* add scnet
* add more configs and minors
* fix unittest for scnet
* update docstring
* add forward test for htc and scnet
* fix build on pytorch 1.3 for empty tensor
* update doc string and minor refactor
* update unit-test and use inheritance for SCNet heads
* support scnet tta
* minor
* update readme
* minor fixes
parent
31809ece87
commit
40f168937d
24 changed files with 1294 additions and 38 deletions
@ -0,0 +1,51 @@ |
||||
# SCNet |
||||
|
||||
## Introduction |
||||
|
||||
[ALGORITHM] |
||||
|
||||
We provide the code for reproducing experiment results of [SCNet](https://arxiv.org/abs/2012.10150). |
||||
|
||||
``` |
||||
@inproceedings{vu2019cascade, |
||||
title={SCNet: Training Inference Sample Consistency for Instance Segmentation}, |
||||
author={Vu, Thang and Haeyong, Kang and Yoo, Chang D}, |
||||
booktitle={AAAI}, |
||||
year={2021} |
||||
} |
||||
``` |
||||
|
||||
## Dataset |
||||
|
||||
SCNet requires the COCO and COCO-stuff datasets for training. Download both and extract them into the COCO dataset path. |
||||
The directory should be organized as follows. |
||||
|
||||
```none |
||||
mmdetection |
||||
├── mmdet |
||||
├── tools |
||||
├── configs |
||||
├── data |
||||
│ ├── coco |
||||
│ │ ├── annotations |
||||
│ │ ├── train2017 |
||||
│ │ ├── val2017 |
||||
│ │ ├── test2017 |
||||
│ │ ├── stuffthingmaps |
||||
``` |
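As a quick sanity check before training, the layout above can be verified with a short script. This is a minimal sketch and not part of MMDetection: the helper name `missing_coco_dirs` and the default root `data/coco` are assumptions for illustration, while the sub-directory names are taken from the tree above.

```python
from pathlib import Path

# Sub-directories taken from the directory tree above.
EXPECTED_SUBDIRS = ("annotations", "train2017", "val2017", "test2017",
                    "stuffthingmaps")


def missing_coco_dirs(root="data/coco"):
    """Return the expected sub-directories that are absent under ``root``.

    Hypothetical helper for illustration; it is not an MMDetection API.
    """
    root = Path(root)
    return [name for name in EXPECTED_SUBDIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_coco_dirs()
    if missing:
        print("Missing directories:", ", ".join(missing))
    else:
        print("COCO layout looks complete.")
```

Run it from the `mmdetection` root so that the relative path `data/coco` resolves correctly, or pass an explicit path to `missing_coco_dirs`.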
||||
|
||||
## Results and Models |
||||
|
||||
The results on COCO val2017 are shown in the table below (results on test-dev are usually slightly higher than on val). |
||||
|
||||
| Backbone | Style | Lr schd | Mem (GB) | Inf speed (fps) | box AP | mask AP | TTA box AP | TTA mask AP | Config | Download | |
||||
|:---------------:|:-------:|:-------:|:--------:|:---------------:|:------:|:-------:|:----------:|:-----------:|:------:|:------------:| |
||||
| R-50-FPN | pytorch | 1x | 7.0 | 6.2 | 43.5 | 39.2 | 44.8 | 40.9 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_1x_coco.py) | [model](https://drive.google.com/file/d/179pcG-sNVDglJoZcQLsM8GWZnJhZtgWx/view?usp=sharing) \| [log](https://drive.google.com/file/d/1ZFS6QhFfxlOnDYPiGpSDP_Fzgb7iDGN3/view?usp=sharing) | |
||||
| R-50-FPN | pytorch | 20e | 7.0 | 6.2 | 44.5 | 40.0 | 45.8 | 41.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1hCH4raiUXsnYQYmKc4ADeL5pNc2X72fJ/view?usp=sharing) \| [log](https://drive.google.com/file/d/1-LnkOXN8n5ojQW34H0qZ625cgrnWpqSX/view?usp=sharing) | |
||||
| R-101-FPN | pytorch | 20e | 8.9 | 5.8 | 45.8 | 40.9 | 47.3 | 42.7 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r101_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1LcnegOX9YYZ7G8hqUa1F_Mp6PeT6_3jA/view?usp=sharing) \| [log](https://drive.google.com/file/d/1iRx-9GRgTaIDsz-we3DGwFVH22nbvCLa/view?usp=sharing) | |
||||
| X-101-64x4d-FPN | pytorch | 20e | 13.2 | 4.9 | 47.5 | 42.3 | 48.9 | 44.0 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_x101_64x4d_fpn_20e_coco.py) | [model](https://drive.google.com/file/d/1MDIoOBKwfXUdtEduz2BQm3MSgC05IlqZ/view?usp=sharing) \| [log](https://drive.google.com/file/d/1OsfQJ8gwtqIQ61k358yxY21sCvbUcRjs/view?usp=sharing) | |
||||
|
||||
### Notes |
||||
|
||||
- Training hyper-parameters are identical to those of [HTC](https://github.com/open-mmlab/mmdetection/tree/master/configs/htc). |
||||
- TTA means Test Time Augmentation, which applies horizontal flip and multi-scale testing. Refer to [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/scnet/scnet_r50_fpn_1x_coco.py).
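
To illustrate what combining flipped and rescaled views involves, here is a minimal sketch in plain Python. It is not MMDetection's implementation: the function names `map_box_to_original` and `average_boxes` are hypothetical, and coordinate-wise averaging is an assumed merging strategy chosen for brevity.

```python
def map_box_to_original(box, scale, flipped, aug_width):
    """Map an (x1, y1, x2, y2) box from an augmented image back to the original.

    ``scale`` is the resize factor applied to the original image,
    ``flipped`` says whether the resized image was horizontally flipped, and
    ``aug_width`` is the width of the augmented image in pixels.
    """
    x1, y1, x2, y2 = box
    if flipped:
        # Undo the horizontal flip: mirror the x-coordinates and swap them
        # so that x1 <= x2 still holds.
        x1, x2 = aug_width - x2, aug_width - x1
    # Undo the resize.
    return (x1 / scale, y1 / scale, x2 / scale, y2 / scale)


def average_boxes(boxes):
    """Average corresponding boxes coordinate-wise (assumed merge strategy)."""
    n = len(boxes)
    return tuple(sum(b[k] for b in boxes) / n for k in range(4))


# The same object seen in the original image and in a 2x, flipped view.
original_view = (10.0, 20.0, 30.0, 40.0)
flipped_view = map_box_to_original((140, 40, 180, 80), scale=2.0,
                                   flipped=True, aug_width=200)
merged = average_boxes([original_view, flipped_view])
print(merged)  # (10.0, 20.0, 30.0, 40.0)
```

The sketch assumes that each view yields predictions for the same set of proposals, so boxes can be matched by index before averaging.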
@ -0,0 +1,2 @@ |
||||
_base_ = './scnet_r50_fpn_20e_coco.py' |
||||
model = dict(pretrained='torchvision://resnet101', backbone=dict(depth=101)) |
@ -0,0 +1,136 @@ |
||||
_base_ = '../htc/htc_r50_fpn_1x_coco.py' |
||||
# model settings |
||||
model = dict( |
||||
type='SCNet', |
||||
roi_head=dict( |
||||
_delete_=True, |
||||
type='SCNetRoIHead', |
||||
num_stages=3, |
||||
stage_loss_weights=[1, 0.5, 0.25], |
||||
bbox_roi_extractor=dict( |
||||
type='SingleRoIExtractor', |
||||
roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), |
||||
out_channels=256, |
||||
featmap_strides=[4, 8, 16, 32]), |
||||
bbox_head=[ |
||||
dict( |
||||
type='SCNetBBoxHead', |
||||
num_shared_fcs=2, |
||||
in_channels=256, |
||||
fc_out_channels=1024, |
||||
roi_feat_size=7, |
||||
num_classes=80, |
||||
bbox_coder=dict( |
||||
type='DeltaXYWHBBoxCoder', |
||||
target_means=[0., 0., 0., 0.], |
||||
target_stds=[0.1, 0.1, 0.2, 0.2]), |
||||
reg_class_agnostic=True, |
||||
loss_cls=dict( |
||||
type='CrossEntropyLoss', |
||||
use_sigmoid=False, |
||||
loss_weight=1.0), |
||||
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, |
||||
loss_weight=1.0)), |
||||
dict( |
||||
type='SCNetBBoxHead', |
||||
num_shared_fcs=2, |
||||
in_channels=256, |
||||
fc_out_channels=1024, |
||||
roi_feat_size=7, |
||||
num_classes=80, |
||||
bbox_coder=dict( |
||||
type='DeltaXYWHBBoxCoder', |
||||
target_means=[0., 0., 0., 0.], |
||||
target_stds=[0.05, 0.05, 0.1, 0.1]), |
||||
reg_class_agnostic=True, |
||||
loss_cls=dict( |
||||
type='CrossEntropyLoss', |
||||
use_sigmoid=False, |
||||
loss_weight=1.0), |
||||
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, |
||||
loss_weight=1.0)), |
||||
dict( |
||||
type='SCNetBBoxHead', |
||||
num_shared_fcs=2, |
||||
in_channels=256, |
||||
fc_out_channels=1024, |
||||
roi_feat_size=7, |
||||
num_classes=80, |
||||
bbox_coder=dict( |
||||
type='DeltaXYWHBBoxCoder', |
||||
target_means=[0., 0., 0., 0.], |
||||
target_stds=[0.033, 0.033, 0.067, 0.067]), |
||||
reg_class_agnostic=True, |
||||
loss_cls=dict( |
||||
type='CrossEntropyLoss', |
||||
use_sigmoid=False, |
||||
loss_weight=1.0), |
||||
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) |
||||
], |
||||
mask_roi_extractor=dict( |
||||
type='SingleRoIExtractor', |
||||
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), |
||||
out_channels=256, |
||||
featmap_strides=[4, 8, 16, 32]), |
||||
mask_head=dict( |
||||
type='SCNetMaskHead', |
||||
num_convs=12, |
||||
in_channels=256, |
||||
conv_out_channels=256, |
||||
num_classes=80, |
||||
conv_to_res=True, |
||||
loss_mask=dict( |
||||
type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)), |
||||
semantic_roi_extractor=dict( |
||||
type='SingleRoIExtractor', |
||||
roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), |
||||
out_channels=256, |
||||
featmap_strides=[8]), |
||||
semantic_head=dict( |
||||
type='SCNetSemanticHead', |
||||
num_ins=5, |
||||
fusion_level=1, |
||||
num_convs=4, |
||||
in_channels=256, |
||||
conv_out_channels=256, |
||||
num_classes=183, |
||||
ignore_label=255, |
||||
loss_weight=0.2, |
||||
conv_to_res=True), |
||||
glbctx_head=dict( |
||||
type='GlobalContextHead', |
||||
num_convs=4, |
||||
in_channels=256, |
||||
conv_out_channels=256, |
||||
num_classes=81, |
||||
loss_weight=3.0, |
||||
conv_to_res=True), |
||||
feat_relay_head=dict( |
||||
type='FeatureRelayHead', |
||||
in_channels=1024, |
||||
out_conv_channels=256, |
||||
roi_feat_size=7, |
||||
scale_factor=2))) |
||||
|
||||
# Uncomment the code below to enable test-time augmentation (TTA). |
||||
# img_norm_cfg = dict( |
||||
# mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) |
||||
# test_pipeline = [ |
||||
# dict(type='LoadImageFromFile'), |
||||
# dict( |
||||
# type='MultiScaleFlipAug', |
||||
# img_scale=[(600, 900), (800, 1200), (1000, 1500), (1200, 1800), |
||||
# (1400, 2100)], |
||||
# flip=True, |
||||
# transforms=[ |
||||
# dict(type='Resize', keep_ratio=True), |
||||
# dict(type='RandomFlip', flip_ratio=0.5), |
||||
# dict(type='Normalize', **img_norm_cfg), |
||||
# dict(type='Pad', size_divisor=32), |
||||
# dict(type='ImageToTensor', keys=['img']), |
||||
# dict(type='Collect', keys=['img']), |
||||
# ]) |
||||
# ] |
||||
# data = dict( |
||||
# val=dict(pipeline=test_pipeline), |
||||
# test=dict(pipeline=test_pipeline)) |
@ -0,0 +1,4 @@ |
||||
_base_ = './scnet_r50_fpn_1x_coco.py' |
||||
# learning policy |
||||
lr_config = dict(step=[16, 19]) |
||||
total_epochs = 20 |
@ -0,0 +1,14 @@ |
||||
_base_ = './scnet_r50_fpn_20e_coco.py' |
||||
model = dict( |
||||
pretrained='open-mmlab://resnext101_64x4d', |
||||
backbone=dict( |
||||
type='ResNeXt', |
||||
depth=101, |
||||
groups=64, |
||||
base_width=4, |
||||
num_stages=4, |
||||
out_indices=(0, 1, 2, 3), |
||||
frozen_stages=1, |
||||
norm_cfg=dict(type='BN', requires_grad=True), |
||||
norm_eval=True, |
||||
style='pytorch')) |
@ -0,0 +1,3 @@ |
||||
_base_ = './scnet_x101_64x4d_fpn_20e_coco.py' |
||||
data = dict(samples_per_gpu=1, workers_per_gpu=1) |
||||
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) |
@ -0,0 +1,10 @@ |
||||
from ..builder import DETECTORS |
||||
from .cascade_rcnn import CascadeRCNN |
||||
|
||||
|
||||
@DETECTORS.register_module() |
||||
class SCNet(CascadeRCNN): |
||||
"""Implementation of `SCNet <https://arxiv.org/abs/2012.10150>`_""" |
||||
|
||||
def __init__(self, **kwargs): |
||||
super(SCNet, self).__init__(**kwargs) |
@ -0,0 +1,76 @@ |
||||
from mmdet.models.builder import HEADS |
||||
from .convfc_bbox_head import ConvFCBBoxHead |
||||
|
||||
|
||||
@HEADS.register_module() |
||||
class SCNetBBoxHead(ConvFCBBoxHead): |
||||
"""BBox head for `SCNet <https://arxiv.org/abs/2012.10150>`_. |
||||
|
||||
This inherits ``ConvFCBBoxHead`` with a modified ``forward()`` function that |
||||
can also return the intermediate shared feature. |
||||
""" |
||||
|
||||
def _forward_shared(self, x): |
||||
"""Forward function for shared part.""" |
||||
if self.num_shared_convs > 0: |
||||
for conv in self.shared_convs: |
||||
x = conv(x) |
||||
|
||||
if self.num_shared_fcs > 0: |
||||
if self.with_avg_pool: |
||||
x = self.avg_pool(x) |
||||
|
||||
x = x.flatten(1) |
||||
|
||||
for fc in self.shared_fcs: |
||||
x = self.relu(fc(x)) |
||||
|
||||
return x |
||||
|
||||
def _forward_cls_reg(self, x): |
||||
"""Forward function for classification and regression parts.""" |
||||
x_cls = x |
||||
x_reg = x |
||||
|
||||
for conv in self.cls_convs: |
||||
x_cls = conv(x_cls) |
||||
if x_cls.dim() > 2: |
||||
if self.with_avg_pool: |
||||
x_cls = self.avg_pool(x_cls) |
||||
x_cls = x_cls.flatten(1) |
||||
for fc in self.cls_fcs: |
||||
x_cls = self.relu(fc(x_cls)) |
||||
|
||||
for conv in self.reg_convs: |
||||
x_reg = conv(x_reg) |
||||
if x_reg.dim() > 2: |
||||
if self.with_avg_pool: |
||||
x_reg = self.avg_pool(x_reg) |
||||
x_reg = x_reg.flatten(1) |
||||
for fc in self.reg_fcs: |
||||
x_reg = self.relu(fc(x_reg)) |
||||
|
||||
cls_score = self.fc_cls(x_cls) if self.with_cls else None |
||||
bbox_pred = self.fc_reg(x_reg) if self.with_reg else None |
||||
|
||||
return cls_score, bbox_pred |
||||
|
||||
def forward(self, x, return_shared_feat=False): |
||||
"""Forward function. |
||||
|
||||
Args: |
||||
x (Tensor): input features |
||||
return_shared_feat (bool): If True, return cls-reg-shared feature. |
||||
|
||||
Returns: |
||||
out (tuple[Tensor]): contain ``cls_score`` and ``bbox_pred``, |
||||
if ``return_shared_feat`` is True, append ``x_shared`` to the |
||||
returned tuple. |
||||
""" |
||||
x_shared = self._forward_shared(x) |
||||
out = self._forward_cls_reg(x_shared) |
||||
|
||||
if return_shared_feat: |
||||
out += (x_shared, ) |
||||
|
||||
return out |
@ -1,12 +1,17 @@ |
||||
from .coarse_mask_head import CoarseMaskHead |
||||
from .fcn_mask_head import FCNMaskHead |
||||
from .feature_relay_head import FeatureRelayHead |
||||
from .fused_semantic_head import FusedSemanticHead |
||||
from .global_context_head import GlobalContextHead |
||||
from .grid_head import GridHead |
||||
from .htc_mask_head import HTCMaskHead |
||||
from .mask_point_head import MaskPointHead |
||||
from .maskiou_head import MaskIoUHead |
||||
from .scnet_mask_head import SCNetMaskHead |
||||
from .scnet_semantic_head import SCNetSemanticHead |
||||
|
||||
__all__ = [ |
||||
'FCNMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead', |
||||
'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead' |
||||
'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead', 'SCNetMaskHead', |
||||
'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead' |
||||
] |
||||
|
@ -0,0 +1,55 @@ |
||||
import torch.nn as nn |
||||
from mmcv.cnn import kaiming_init |
||||
from mmcv.runner import auto_fp16 |
||||
|
||||
from mmdet.models.builder import HEADS |
||||
|
||||
|
||||
@HEADS.register_module() |
||||
class FeatureRelayHead(nn.Module): |
||||
"""Feature Relay Head used in SCNet https://arxiv.org/abs/2012.10150. |
||||
|
||||
Args: |
||||
in_channels (int, optional): number of input channels. Default: 1024. |
||||
out_conv_channels (int, optional): number of output channels before |
||||
classification layer. Default: 256. |
||||
roi_feat_size (int, optional): roi feat size at box head. Default: 7. |
||||
scale_factor (int, optional): scale factor to match roi feat size |
||||
at mask head. Default: 2. |
||||
""" |
||||
|
||||
def __init__(self, |
||||
in_channels=1024, |
||||
out_conv_channels=256, |
||||
roi_feat_size=7, |
||||
scale_factor=2): |
||||
super(FeatureRelayHead, self).__init__() |
||||
assert isinstance(roi_feat_size, int) |
||||
|
||||
self.in_channels = in_channels |
||||
self.out_conv_channels = out_conv_channels |
||||
self.roi_feat_size = roi_feat_size |
||||
self.out_channels = (roi_feat_size**2) * out_conv_channels |
||||
self.scale_factor = scale_factor |
||||
self.fp16_enabled = False |
||||
|
||||
self.fc = nn.Linear(self.in_channels, self.out_channels) |
||||
self.upsample = nn.Upsample( |
||||
scale_factor=scale_factor, mode='bilinear', align_corners=True) |
||||
|
||||
def init_weights(self): |
||||
"""Init weights for the head.""" |
||||
kaiming_init(self.fc) |
||||
|
||||
@auto_fp16() |
||||
def forward(self, x): |
||||
"""Forward function.""" |
||||
N, in_C = x.shape |
||||
if N > 0: |
||||
out_C = self.out_conv_channels |
||||
out_HW = self.roi_feat_size |
||||
x = self.fc(x) |
||||
x = x.reshape(N, out_C, out_HW, out_HW) |
||||
x = self.upsample(x) |
||||
return x |
||||
return None |
@ -0,0 +1,101 @@ |
||||
import torch.nn as nn |
||||
from mmcv.cnn import ConvModule |
||||
from mmcv.runner import auto_fp16, force_fp32 |
||||
|
||||
from mmdet.models.builder import HEADS |
||||
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock |
||||
|
||||
|
||||
@HEADS.register_module() |
||||
class GlobalContextHead(nn.Module): |
||||
"""Global context head used in SCNet https://arxiv.org/abs/2012.10150. |
||||
|
||||
Args: |
||||
num_convs (int, optional): number of convolutional layers in GlbCtxHead. |
||||
Default: 4. |
||||
in_channels (int, optional): number of input channels. Default: 256. |
||||
conv_out_channels (int, optional): number of output channels before |
||||
classification layer. Default: 256. |
||||
loss_weight (float, optional): global context loss weight. Default: 1. |
||||
conv_cfg (dict, optional): config to init conv layer. Default: None. |
||||
norm_cfg (dict, optional): config to init norm layer. Default: None. |
||||
conv_to_res (bool, optional): if True, 2 convs will be grouped into |
||||
1 `SimplifiedBasicBlock` using a skip connection. Default: False. |
||||
""" |
||||
|
||||
def __init__(self, |
||||
num_convs=4, |
||||
in_channels=256, |
||||
conv_out_channels=256, |
||||
num_classes=81, |
||||
loss_weight=1.0, |
||||
conv_cfg=None, |
||||
norm_cfg=None, |
||||
conv_to_res=False): |
||||
super(GlobalContextHead, self).__init__() |
||||
self.num_convs = num_convs |
||||
self.in_channels = in_channels |
||||
self.conv_out_channels = conv_out_channels |
||||
self.num_classes = num_classes |
||||
self.loss_weight = loss_weight |
||||
self.conv_cfg = conv_cfg |
||||
self.norm_cfg = norm_cfg |
||||
self.conv_to_res = conv_to_res |
||||
self.fp16_enabled = False |
||||
|
||||
if self.conv_to_res: |
||||
num_res_blocks = num_convs // 2 |
||||
self.convs = ResLayer( |
||||
SimplifiedBasicBlock, |
||||
in_channels, |
||||
self.conv_out_channels, |
||||
num_res_blocks, |
||||
conv_cfg=self.conv_cfg, |
||||
norm_cfg=self.norm_cfg) |
||||
self.num_convs = num_res_blocks |
||||
else: |
||||
self.convs = nn.ModuleList() |
||||
for i in range(self.num_convs): |
||||
in_channels = self.in_channels if i == 0 else conv_out_channels |
||||
self.convs.append( |
||||
ConvModule( |
||||
in_channels, |
||||
conv_out_channels, |
||||
3, |
||||
padding=1, |
||||
conv_cfg=self.conv_cfg, |
||||
norm_cfg=self.norm_cfg)) |
||||
|
||||
self.pool = nn.AdaptiveAvgPool2d(1) |
||||
self.fc = nn.Linear(conv_out_channels, num_classes) |
||||
|
||||
self.criterion = nn.BCEWithLogitsLoss() |
||||
|
||||
def init_weights(self): |
||||
"""Init weights for the head.""" |
||||
nn.init.normal_(self.fc.weight, 0, 0.01) |
||||
nn.init.constant_(self.fc.bias, 0) |
||||
|
||||
@auto_fp16() |
||||
def forward(self, feats): |
||||
"""Forward function.""" |
||||
x = feats[-1] |
||||
for i in range(self.num_convs): |
||||
x = self.convs[i](x) |
||||
x = self.pool(x) |
||||
|
||||
# multi-class prediction |
||||
mc_pred = x.reshape(x.size(0), -1) |
||||
mc_pred = self.fc(mc_pred) |
||||
|
||||
return mc_pred, x |
||||
|
||||
@force_fp32(apply_to=('pred', )) |
||||
def loss(self, pred, labels): |
||||
"""Loss function.""" |
||||
labels = [lbl.unique() for lbl in labels] |
||||
targets = pred.new_zeros(pred.size()) |
||||
for i, label in enumerate(labels): |
||||
targets[i, label] = 1.0 |
||||
loss = self.loss_weight * self.criterion(pred, targets) |
||||
return loss |
@ -0,0 +1,27 @@ |
||||
from mmdet.models.builder import HEADS |
||||
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock |
||||
from .fcn_mask_head import FCNMaskHead |
||||
|
||||
|
||||
@HEADS.register_module() |
||||
class SCNetMaskHead(FCNMaskHead): |
||||
"""Mask head for `SCNet <https://arxiv.org/abs/2012.10150>`_. |
||||
|
||||
Args: |
||||
conv_to_res (bool, optional): if True, change the conv layers to |
||||
``SimplifiedBasicBlock``. |
||||
""" |
||||
|
||||
def __init__(self, conv_to_res=True, **kwargs): |
||||
super(SCNetMaskHead, self).__init__(**kwargs) |
||||
self.conv_to_res = conv_to_res |
||||
if conv_to_res: |
||||
assert self.conv_kernel_size == 3 |
||||
self.num_res_blocks = self.num_convs // 2 |
||||
self.convs = ResLayer( |
||||
SimplifiedBasicBlock, |
||||
self.in_channels, |
||||
self.conv_out_channels, |
||||
self.num_res_blocks, |
||||
conv_cfg=self.conv_cfg, |
||||
norm_cfg=self.norm_cfg) |
@ -0,0 +1,27 @@ |
||||
from mmdet.models.builder import HEADS |
||||
from mmdet.models.utils import ResLayer, SimplifiedBasicBlock |
||||
from .fused_semantic_head import FusedSemanticHead |
||||
|
||||
|
||||
@HEADS.register_module() |
||||
class SCNetSemanticHead(FusedSemanticHead): |
||||
"""Semantic head for `SCNet <https://arxiv.org/abs/2012.10150>`_. |
||||
|
||||
Args: |
||||
conv_to_res (bool, optional): if True, change the conv layers to |
||||
``SimplifiedBasicBlock``. |
||||
""" |
||||
|
||||
def __init__(self, conv_to_res=True, **kwargs): |
||||
super(SCNetSemanticHead, self).__init__(**kwargs) |
||||
self.conv_to_res = conv_to_res |
||||
if self.conv_to_res: |
||||
num_res_blocks = self.num_convs // 2 |
||||
self.convs = ResLayer( |
||||
SimplifiedBasicBlock, |
||||
self.in_channels, |
||||
self.conv_out_channels, |
||||
num_res_blocks, |
||||
conv_cfg=self.conv_cfg, |
||||
norm_cfg=self.norm_cfg) |
||||
self.num_convs = num_res_blocks |
@ -0,0 +1,582 @@ |
||||
import torch |
||||
import torch.nn.functional as F |
||||
|
||||
from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes, |
||||
merge_aug_masks, multiclass_nms) |
||||
from ..builder import HEADS, build_head, build_roi_extractor |
||||
from .cascade_roi_head import CascadeRoIHead |
||||
|
||||
|
||||
@HEADS.register_module() |
||||
class SCNetRoIHead(CascadeRoIHead): |
||||
"""RoIHead for `SCNet <https://arxiv.org/abs/2012.10150>`_. |
||||
|
||||
Args: |
||||
num_stages (int): number of cascade stages. |
||||
stage_loss_weights (list): loss weight of cascade stages. |
||||
semantic_roi_extractor (dict): config to init semantic roi extractor. |
||||
semantic_head (dict): config to init semantic head. |
||||
feat_relay_head (dict): config to init feature_relay_head. |
||||
glbctx_head (dict): config to init global context head. |
||||
""" |
||||
|
||||
def __init__(self, |
||||
num_stages, |
||||
stage_loss_weights, |
||||
semantic_roi_extractor=None, |
||||
semantic_head=None, |
||||
feat_relay_head=None, |
||||
glbctx_head=None, |
||||
**kwargs): |
||||
super(SCNetRoIHead, self).__init__(num_stages, stage_loss_weights, |
||||
**kwargs) |
||||
assert self.with_bbox and self.with_mask |
||||
assert not self.with_shared_head # shared head is not supported |
||||
|
||||
if semantic_head is not None: |
||||
self.semantic_roi_extractor = build_roi_extractor( |
||||
semantic_roi_extractor) |
||||
self.semantic_head = build_head(semantic_head) |
||||
|
||||
if feat_relay_head is not None: |
||||
self.feat_relay_head = build_head(feat_relay_head) |
||||
|
||||
if glbctx_head is not None: |
||||
self.glbctx_head = build_head(glbctx_head) |
||||
|
||||
def init_mask_head(self, mask_roi_extractor, mask_head): |
||||
"""Initialize ``mask_head``""" |
||||
if mask_roi_extractor is not None: |
||||
self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor) |
||||
self.mask_head = build_head(mask_head) |
||||
|
||||
def init_weights(self, pretrained): |
||||
"""Initialize the weights in head. |
||||
|
||||
Args: |
||||
pretrained (str, optional): Path to pre-trained weights. |
||||
Defaults to None. |
||||
""" |
||||
for i in range(self.num_stages): |
||||
if self.with_bbox: |
||||
self.bbox_roi_extractor[i].init_weights() |
||||
self.bbox_head[i].init_weights() |
||||
if self.with_mask: |
||||
self.mask_roi_extractor.init_weights() |
||||
self.mask_head.init_weights() |
||||
if self.with_semantic: |
||||
self.semantic_head.init_weights() |
||||
if self.with_glbctx: |
||||
self.glbctx_head.init_weights() |
||||
if self.with_feat_relay: |
||||
self.feat_relay_head.init_weights() |
||||
|
||||
@property |
||||
def with_semantic(self): |
||||
"""bool: whether the head has semantic head""" |
||||
return hasattr(self, |
||||
'semantic_head') and self.semantic_head is not None |
||||
|
||||
@property |
||||
def with_feat_relay(self): |
||||
"""bool: whether the head has feature relay head""" |
||||
return (hasattr(self, 'feat_relay_head') |
||||
and self.feat_relay_head is not None) |
||||
|
||||
@property |
||||
def with_glbctx(self): |
||||
"""bool: whether the head has global context head""" |
||||
return hasattr(self, 'glbctx_head') and self.glbctx_head is not None |
||||
|
||||
def _fuse_glbctx(self, roi_feats, glbctx_feat, rois): |
||||
"""Fuse global context feats with roi feats.""" |
||||
assert roi_feats.size(0) == rois.size(0) |
||||
img_inds = torch.unique(rois[:, 0].cpu(), sorted=True).long() |
||||
fused_feats = torch.zeros_like(roi_feats) |
||||
for img_id in img_inds: |
||||
inds = (rois[:, 0] == img_id.item()) |
||||
fused_feats[inds] = roi_feats[inds] + glbctx_feat[img_id] |
||||
return fused_feats |
||||
|
||||
def _slice_pos_feats(self, feats, sampling_results): |
||||
"""Get features from pos rois.""" |
||||
num_rois = [res.bboxes.size(0) for res in sampling_results] |
||||
num_pos_rois = [res.pos_bboxes.size(0) for res in sampling_results] |
||||
inds = torch.zeros(sum(num_rois), dtype=torch.bool) |
||||
start = 0 |
||||
for i in range(len(num_rois)): |
||||
start = 0 if i == 0 else start + num_rois[i - 1] |
||||
stop = start + num_pos_rois[i] |
||||
inds[start:stop] = 1 |
||||
sliced_feats = feats[inds] |
||||
return sliced_feats |
||||
|
||||
def _bbox_forward(self, |
||||
stage, |
||||
x, |
||||
rois, |
||||
semantic_feat=None, |
||||
glbctx_feat=None): |
||||
"""Box head forward function used in both training and testing.""" |
||||
bbox_roi_extractor = self.bbox_roi_extractor[stage] |
||||
bbox_head = self.bbox_head[stage] |
||||
bbox_feats = bbox_roi_extractor( |
||||
x[:len(bbox_roi_extractor.featmap_strides)], rois) |
||||
if self.with_semantic and semantic_feat is not None: |
||||
bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat], |
||||
rois) |
||||
if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]: |
||||
bbox_semantic_feat = F.adaptive_avg_pool2d( |
||||
bbox_semantic_feat, bbox_feats.shape[-2:]) |
||||
bbox_feats += bbox_semantic_feat |
||||
if self.with_glbctx and glbctx_feat is not None: |
||||
bbox_feats = self._fuse_glbctx(bbox_feats, glbctx_feat, rois) |
||||
cls_score, bbox_pred, relayed_feat = bbox_head( |
||||
bbox_feats, return_shared_feat=True) |
||||
|
||||
bbox_results = dict( |
||||
cls_score=cls_score, |
||||
bbox_pred=bbox_pred, |
||||
relayed_feat=relayed_feat) |
||||
return bbox_results |
||||
|
||||
def _mask_forward(self, |
||||
x, |
||||
rois, |
||||
semantic_feat=None, |
||||
glbctx_feat=None, |
||||
relayed_feat=None): |
||||
"""Mask head forward function used in both training and testing.""" |
||||
mask_feats = self.mask_roi_extractor( |
||||
x[:self.mask_roi_extractor.num_inputs], rois) |
||||
if self.with_semantic and semantic_feat is not None: |
||||
mask_semantic_feat = self.semantic_roi_extractor([semantic_feat], |
||||
rois) |
||||
if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]: |
||||
mask_semantic_feat = F.adaptive_avg_pool2d( |
||||
mask_semantic_feat, mask_feats.shape[-2:]) |
||||
mask_feats += mask_semantic_feat |
||||
if self.with_glbctx and glbctx_feat is not None: |
||||
mask_feats = self._fuse_glbctx(mask_feats, glbctx_feat, rois) |
||||
if self.with_feat_relay and relayed_feat is not None: |
||||
mask_feats = mask_feats + relayed_feat |
||||
mask_pred = self.mask_head(mask_feats) |
||||
mask_results = dict(mask_pred=mask_pred) |
||||
|
||||
return mask_results |
||||
|
||||
def _bbox_forward_train(self, |
||||
stage, |
||||
x, |
||||
sampling_results, |
||||
gt_bboxes, |
||||
gt_labels, |
||||
rcnn_train_cfg, |
||||
semantic_feat=None, |
||||
glbctx_feat=None): |
||||
"""Run forward function and calculate loss for box head in training.""" |
||||
bbox_head = self.bbox_head[stage] |
||||
rois = bbox2roi([res.bboxes for res in sampling_results]) |
||||
bbox_results = self._bbox_forward( |
||||
stage, |
||||
x, |
||||
rois, |
||||
semantic_feat=semantic_feat, |
||||
glbctx_feat=glbctx_feat) |
||||
|
||||
bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes, |
||||
gt_labels, rcnn_train_cfg) |
||||
loss_bbox = bbox_head.loss(bbox_results['cls_score'], |
||||
bbox_results['bbox_pred'], rois, |
||||
*bbox_targets) |
||||
|
||||
bbox_results.update( |
||||
loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets) |
||||
return bbox_results |
||||
|
||||
def _mask_forward_train(self, |
||||
x, |
||||
sampling_results, |
||||
gt_masks, |
||||
rcnn_train_cfg, |
||||
semantic_feat=None, |
||||
glbctx_feat=None, |
||||
relayed_feat=None): |
||||
"""Run forward function and calculate loss for mask head in |
||||
training.""" |
||||
pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results]) |
||||
mask_results = self._mask_forward( |
||||
x, |
||||
pos_rois, |
||||
semantic_feat=semantic_feat, |
||||
glbctx_feat=glbctx_feat, |
||||
relayed_feat=relayed_feat) |
||||
|
||||
mask_targets = self.mask_head.get_targets(sampling_results, gt_masks, |
||||
rcnn_train_cfg) |
||||
pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results]) |
||||
loss_mask = self.mask_head.loss(mask_results['mask_pred'], |
||||
mask_targets, pos_labels) |
||||
|
||||
mask_results = loss_mask |
||||
return mask_results |
||||
|
||||
def forward_train(self, |
||||
x, |
||||
img_metas, |
||||
proposal_list, |
||||
gt_bboxes, |
||||
gt_labels, |
||||
gt_bboxes_ignore=None, |
||||
gt_masks=None, |
||||
gt_semantic_seg=None): |
||||
""" |
||||
Args: |
||||
x (list[Tensor]): list of multi-level img features. |
||||
|
||||
img_metas (list[dict]): list of image info dict where each dict |
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain |
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. |
||||
For details on the values of these keys see |
||||
`mmdet/datasets/pipelines/formatting.py:Collect`. |
||||
|
||||
proposal_list (list[Tensors]): list of region proposals. |
||||
|
||||
gt_bboxes (list[Tensor]): Ground truth bboxes for each image with |
||||
shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. |
||||
|
||||
gt_labels (list[Tensor]): class indices corresponding to each box |
||||
|
||||
gt_bboxes_ignore (None, list[Tensor]): specify which bounding |
||||
boxes can be ignored when computing the loss. |
||||
|
||||
gt_masks (None, Tensor) : true segmentation masks for each box |
||||
used if the architecture supports a segmentation task. |
||||
|
||||
gt_semantic_seg (None, list[Tensor]): semantic segmentation masks |
||||
used if the architecture supports semantic segmentation task. |
||||
|
||||
Returns: |
||||
dict[str, Tensor]: a dictionary of loss components |
||||
""" |
||||
losses = dict() |
||||
|
||||
# semantic segmentation branch |
||||
if self.with_semantic: |
||||
semantic_pred, semantic_feat = self.semantic_head(x) |
||||
loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg) |
||||
losses['loss_semantic_seg'] = loss_seg |
||||
else: |
||||
semantic_feat = None |
||||
|
||||
# global context branch |
||||
if self.with_glbctx: |
||||
mc_pred, glbctx_feat = self.glbctx_head(x) |
||||
loss_glbctx = self.glbctx_head.loss(mc_pred, gt_labels) |
||||
losses['loss_glbctx'] = loss_glbctx |
||||
else: |
||||
glbctx_feat = None |
||||
|
||||
for i in range(self.num_stages): |
||||
self.current_stage = i |
||||
rcnn_train_cfg = self.train_cfg[i] |
||||
lw = self.stage_loss_weights[i] |
||||
|
||||
# assign gts and sample proposals |
||||
sampling_results = [] |
||||
bbox_assigner = self.bbox_assigner[i] |
||||
bbox_sampler = self.bbox_sampler[i] |
||||
num_imgs = len(img_metas) |
||||
if gt_bboxes_ignore is None: |
||||
gt_bboxes_ignore = [None for _ in range(num_imgs)] |
||||
|
||||
for j in range(num_imgs): |
||||
assign_result = bbox_assigner.assign(proposal_list[j], |
||||
gt_bboxes[j], |
||||
gt_bboxes_ignore[j], |
||||
gt_labels[j]) |
||||
sampling_result = bbox_sampler.sample( |
||||
assign_result, |
||||
proposal_list[j], |
||||
gt_bboxes[j], |
||||
gt_labels[j], |
||||
feats=[lvl_feat[j][None] for lvl_feat in x]) |
||||
sampling_results.append(sampling_result) |
||||
|
||||
bbox_results = \ |
||||
self._bbox_forward_train( |
||||
i, x, sampling_results, gt_bboxes, gt_labels, |
||||
rcnn_train_cfg, semantic_feat, glbctx_feat) |
||||
roi_labels = bbox_results['bbox_targets'][0] |
||||
|
||||
for name, value in bbox_results['loss_bbox'].items(): |
||||
losses[f's{i}.{name}'] = ( |
||||
value * lw if 'loss' in name else value) |
||||
|
||||
# refine boxes |
||||
if i < self.num_stages - 1: |
||||
pos_is_gts = [res.pos_is_gt for res in sampling_results] |
||||
with torch.no_grad(): |
||||
proposal_list = self.bbox_head[i].refine_bboxes( |
||||
bbox_results['rois'], roi_labels, |
||||
bbox_results['bbox_pred'], pos_is_gts, img_metas) |
||||
|
||||
if self.with_feat_relay: |
||||
relayed_feat = self._slice_pos_feats(bbox_results['relayed_feat'], |
||||
sampling_results) |
||||
relayed_feat = self.feat_relay_head(relayed_feat) |
||||
else: |
||||
relayed_feat = None |
||||
|
||||
mask_results = self._mask_forward_train(x, sampling_results, gt_masks, |
||||
rcnn_train_cfg, semantic_feat, |
||||
glbctx_feat, relayed_feat) |
||||
mask_lw = sum(self.stage_loss_weights) |
||||
losses['loss_mask'] = mask_lw * mask_results['loss_mask'] |
||||
|
||||
return losses |
||||
|
||||
def simple_test(self, x, proposal_list, img_metas, rescale=False): |
||||
"""Test without augmentation.""" |
||||
if self.with_semantic: |
||||
_, semantic_feat = self.semantic_head(x) |
||||
else: |
||||
semantic_feat = None |
||||
|
||||
if self.with_glbctx: |
||||
_, glbctx_feat = self.glbctx_head(x) |
||||
else: |
||||
glbctx_feat = None |
||||
|
||||
num_imgs = len(proposal_list) |
||||
img_shapes = tuple(meta['img_shape'] for meta in img_metas) |
||||
ori_shapes = tuple(meta['ori_shape'] for meta in img_metas) |
||||
scale_factors = tuple(meta['scale_factor'] for meta in img_metas) |
||||
|
||||
# "ms" in variable names means multi-stage |
||||
ms_scores = [] |
||||
rcnn_test_cfg = self.test_cfg |
||||
|
||||
rois = bbox2roi(proposal_list) |
||||
for i in range(self.num_stages): |
||||
bbox_head = self.bbox_head[i] |
||||
bbox_results = self._bbox_forward( |
||||
i, |
||||
x, |
||||
rois, |
||||
semantic_feat=semantic_feat, |
||||
glbctx_feat=glbctx_feat) |
||||
# split batch bbox prediction back to each image |
||||
cls_score = bbox_results['cls_score'] |
||||
bbox_pred = bbox_results['bbox_pred'] |
||||
num_proposals_per_img = tuple(len(p) for p in proposal_list) |
||||
rois = rois.split(num_proposals_per_img, 0) |
||||
cls_score = cls_score.split(num_proposals_per_img, 0) |
||||
bbox_pred = bbox_pred.split(num_proposals_per_img, 0) |
||||
ms_scores.append(cls_score) |
||||
|
||||
if i < self.num_stages - 1: |
||||
bbox_label = [s[:, :-1].argmax(dim=1) for s in cls_score] |
||||
rois = torch.cat([ |
||||
bbox_head.regress_by_class(rois[j], bbox_label[j], |
||||
bbox_pred[j], img_metas[j]) |
||||
for j in range(num_imgs) |
||||
]) |
||||
|
||||
# average scores of each image by stages |
||||
cls_score = [ |
||||
sum([score[i] for score in ms_scores]) / float(len(ms_scores)) |
||||
for i in range(num_imgs) |
||||
] |
||||
|
||||
# apply bbox post-processing to each image individually |
||||
det_bboxes = [] |
||||
det_labels = [] |
||||
for i in range(num_imgs): |
||||
det_bbox, det_label = self.bbox_head[-1].get_bboxes( |
||||
rois[i], |
||||
cls_score[i], |
||||
bbox_pred[i], |
||||
img_shapes[i], |
||||
scale_factors[i], |
||||
rescale=rescale, |
||||
cfg=rcnn_test_cfg) |
||||
det_bboxes.append(det_bbox) |
||||
det_labels.append(det_label) |
||||
det_bbox_results = [ |
||||
bbox2result(det_bboxes[i], det_labels[i], |
||||
self.bbox_head[-1].num_classes) |
||||
for i in range(num_imgs) |
||||
] |
||||
|
||||
if self.with_mask: |
||||
if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes): |
||||
mask_classes = self.mask_head.num_classes |
||||
det_segm_results = [[[] for _ in range(mask_classes)] |
||||
for _ in range(num_imgs)] |
||||
else: |
||||
if rescale and not isinstance(scale_factors[0], float): |
||||
scale_factors = [ |
||||
torch.from_numpy(scale_factor).to(det_bboxes[0].device) |
||||
for scale_factor in scale_factors |
||||
] |
||||
_bboxes = [ |
||||
det_bboxes[i][:, :4] * |
||||
scale_factors[i] if rescale else det_bboxes[i] |
||||
for i in range(num_imgs) |
||||
] |
||||
mask_rois = bbox2roi(_bboxes) |
||||
|
||||
# get relay feature on mask_rois |
||||
bbox_results = self._bbox_forward( |
||||
-1, |
||||
x, |
||||
mask_rois, |
||||
semantic_feat=semantic_feat, |
||||
glbctx_feat=glbctx_feat) |
||||
relayed_feat = bbox_results['relayed_feat'] |
||||
relayed_feat = self.feat_relay_head(relayed_feat) |
||||
|
||||
mask_results = self._mask_forward( |
||||
x, |
||||
mask_rois, |
||||
semantic_feat=semantic_feat, |
||||
glbctx_feat=glbctx_feat, |
||||
relayed_feat=relayed_feat) |
||||
mask_pred = mask_results['mask_pred'] |
||||
|
||||
# split batch mask prediction back to each image |
||||
num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes) |
||||
mask_preds = mask_pred.split(num_bbox_per_img, 0) |
||||
|
||||
# apply mask post-processing to each image individually |
||||
det_segm_results = [] |
||||
for i in range(num_imgs): |
||||
if det_bboxes[i].shape[0] == 0: |
||||
det_segm_results.append( |
||||
[[] for _ in range(self.mask_head.num_classes)]) |
||||
else: |
||||
segm_result = self.mask_head.get_seg_masks( |
||||
mask_preds[i], _bboxes[i], det_labels[i], |
||||
self.test_cfg, ori_shapes[i], scale_factors[i], |
||||
rescale) |
||||
det_segm_results.append(segm_result) |
||||
|
||||
# return results |
||||
if self.with_mask: |
||||
return list(zip(det_bbox_results, det_segm_results)) |
||||
else: |
||||
return det_bbox_results |
||||
|
||||
def aug_test(self, img_feats, proposal_list, img_metas, rescale=False): |
||||
"""Test with augmentations (a single image per batch).""" |
||||
if self.with_semantic: |
||||
semantic_feats = [ |
||||
self.semantic_head(feat)[1] for feat in img_feats |
||||
] |
||||
else: |
||||
semantic_feats = [None] * len(img_metas) |
||||
|
||||
if self.with_glbctx: |
||||
glbctx_feats = [self.glbctx_head(feat)[1] for feat in img_feats] |
||||
else: |
||||
glbctx_feats = [None] * len(img_metas) |
||||
|
||||
rcnn_test_cfg = self.test_cfg |
||||
aug_bboxes = [] |
||||
aug_scores = [] |
||||
for x, img_meta, semantic_feat, glbctx_feat in zip( |
||||
img_feats, img_metas, semantic_feats, glbctx_feats): |
||||
# only one image in the batch |
||||
img_shape = img_meta[0]['img_shape'] |
||||
scale_factor = img_meta[0]['scale_factor'] |
||||
flip = img_meta[0]['flip'] |
||||
|
||||
proposals = bbox_mapping(proposal_list[0][:, :4], img_shape, |
||||
scale_factor, flip) |
||||
# "ms" in variable names means multi-stage |
||||
ms_scores = [] |
||||
|
||||
rois = bbox2roi([proposals]) |
||||
for i in range(self.num_stages): |
||||
bbox_head = self.bbox_head[i] |
||||
bbox_results = self._bbox_forward( |
||||
i, |
||||
x, |
||||
rois, |
||||
semantic_feat=semantic_feat, |
||||
glbctx_feat=glbctx_feat) |
||||
ms_scores.append(bbox_results['cls_score']) |
||||
if i < self.num_stages - 1: |
||||
bbox_label = bbox_results['cls_score'][:, :-1].argmax(dim=1) |
||||
rois = bbox_head.regress_by_class( |
||||
rois, bbox_label, bbox_results['bbox_pred'], |
||||
img_meta[0]) |
||||
|
||||
cls_score = sum(ms_scores) / float(len(ms_scores)) |
||||
bboxes, scores = self.bbox_head[-1].get_bboxes( |
||||
rois, |
||||
cls_score, |
||||
bbox_results['bbox_pred'], |
||||
img_shape, |
||||
scale_factor, |
||||
rescale=False, |
||||
cfg=None) |
||||
aug_bboxes.append(bboxes) |
||||
aug_scores.append(scores) |
||||
|
||||
# after merging, bboxes will be rescaled to the original image size |
||||
merged_bboxes, merged_scores = merge_aug_bboxes( |
||||
aug_bboxes, aug_scores, img_metas, rcnn_test_cfg) |
||||
det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores, |
||||
rcnn_test_cfg.score_thr, |
||||
rcnn_test_cfg.nms, |
||||
rcnn_test_cfg.max_per_img) |
||||
|
||||
det_bbox_results = bbox2result(det_bboxes, det_labels, |
||||
self.bbox_head[-1].num_classes) |
||||
|
||||
if self.with_mask: |
||||
if det_bboxes.shape[0] == 0: |
||||
det_segm_results = [[] |
||||
for _ in range(self.mask_head.num_classes)] |
||||
else: |
||||
aug_masks = [] |
||||
for x, img_meta, semantic_feat, glbctx_feat in zip( |
||||
img_feats, img_metas, semantic_feats, glbctx_feats): |
||||
img_shape = img_meta[0]['img_shape'] |
||||
scale_factor = img_meta[0]['scale_factor'] |
||||
flip = img_meta[0]['flip'] |
||||
_bboxes = bbox_mapping(det_bboxes[:, :4], img_shape, |
||||
scale_factor, flip) |
||||
mask_rois = bbox2roi([_bboxes]) |
||||
# get relay feature on mask_rois |
||||
bbox_results = self._bbox_forward( |
||||
-1, |
||||
x, |
||||
mask_rois, |
||||
semantic_feat=semantic_feat, |
||||
glbctx_feat=glbctx_feat) |
||||
relayed_feat = bbox_results['relayed_feat'] |
||||
relayed_feat = self.feat_relay_head(relayed_feat) |
||||
mask_results = self._mask_forward( |
||||
x, |
||||
mask_rois, |
||||
semantic_feat=semantic_feat, |
||||
glbctx_feat=glbctx_feat, |
||||
relayed_feat=relayed_feat) |
||||
mask_pred = mask_results['mask_pred'] |
||||
aug_masks.append(mask_pred.sigmoid().cpu().numpy()) |
||||
merged_masks = merge_aug_masks(aug_masks, img_metas, |
||||
self.test_cfg) |
||||
ori_shape = img_metas[0][0]['ori_shape'] |
||||
det_segm_results = self.mask_head.get_seg_masks( |
||||
merged_masks, |
||||
det_bboxes, |
||||
det_labels, |
||||
rcnn_test_cfg, |
||||
ori_shape, |
||||
scale_factor=1.0, |
||||
rescale=False) |
||||
return [(det_bbox_results, det_segm_results)] |
||||
else: |
||||
return [det_bbox_results] |
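
Both `simple_test` and `aug_test` above collect one classification score per cascade stage in `ms_scores` and average them before post-processing. The following is a minimal sketch of that averaging step, assuming plain nested Python lists in place of tensors; `average_stage_scores` is a hypothetical helper written for illustration and is not part of the code in this commit.

```python
def average_stage_scores(ms_scores):
    """Average per-image class scores over cascade stages.

    ms_scores[s][k] is the score matrix for image k at stage s,
    stored as a list of rows (one row per proposal, one column per class).
    """
    num_stages = len(ms_scores)
    num_imgs = len(ms_scores[0])
    averaged = []
    for k in range(num_imgs):
        # Collect the score matrix of image k from every stage.
        per_stage = [stage[k] for stage in ms_scores]
        img_avg = [
            # zip(*rows) groups the values of one class across stages.
            [sum(vals) / float(num_stages) for vals in zip(*rows)]
            # zip(*per_stage) groups the rows of one proposal across stages.
            for rows in zip(*per_stage)
        ]
        averaged.append(img_avg)
    return averaged


# Two stages, one image, two proposals, two classes (illustrative values).
stage0 = ([[0.25, 0.75], [0.5, 0.5]],)
stage1 = ([[0.75, 0.25], [1.0, 0.0]],)
result = average_stage_scores([stage0, stage1])
print(result)
```

With tensors, the same result comes from `sum(scores) / float(len(ms_scores))`, which is how the methods above compute it. The list version only spells out which axis is being averaged.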