OpenMMLab Detection Toolbox and Benchmark https://mmdetection.readthedocs.io/
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

155 lines
6.5 KiB

# Copyright (c) OpenMMLab. All rights reserved.
import warnings
import torch
import torch.nn as nn
from mmcv.runner import ModuleList
from ..builder import HEADS
from ..utils import ConvUpsample
from .base_semantic_head import BaseSemanticHead
@HEADS.register_module()
class PanopticFPNHead(BaseSemanticHead):
"""PanopticFPNHead used in Panoptic FPN.
In this head, the number of output channels is ``num_stuff_classes
+ 1``, including all stuff classes and one thing class. The stuff
classes will be reset from ``0`` to ``num_stuff_classes - 1``, the
thing classes will be merged to ``num_stuff_classes``-th channel.
Arg:
num_things_classes (int): Number of thing classes. Default: 80.
num_stuff_classes (int): Number of stuff classes. Default: 53.
num_classes (int): Number of classes, including all stuff
classes and one thing class. This argument is deprecated,
please use ``num_things_classes`` and ``num_stuff_classes``.
The module will automatically infer the num_classes by
``num_stuff_classes + 1``.
in_channels (int): Number of channels in the input feature
map.
inner_channels (int): Number of channels in inner features.
start_level (int): The start level of the input features
used in PanopticFPN.
end_level (int): The end level of the used features, the
``end_level``-th layer will not be used.
fg_range (tuple): Range of the foreground classes. It starts
from ``0`` to ``num_things_classes-1``. Deprecated, please use
``num_things_classes`` directly.
bg_range (tuple): Range of the background classes. It starts
from ``num_things_classes`` to ``num_things_classes +
num_stuff_classes - 1``. Deprecated, please use
``num_stuff_classes`` and ``num_things_classes`` directly.
conv_cfg (dict): Dictionary to construct and config
conv layer. Default: None.
norm_cfg (dict): Dictionary to construct and config norm layer.
Use ``GN`` by default.
init_cfg (dict or list[dict], optional): Initialization config dict.
loss_seg (dict): the loss of the semantic head.
"""
def __init__(self,
num_things_classes=80,
num_stuff_classes=53,
num_classes=None,
in_channels=256,
inner_channels=128,
start_level=0,
end_level=4,
fg_range=None,
bg_range=None,
conv_cfg=None,
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
init_cfg=None,
loss_seg=dict(
type='CrossEntropyLoss', ignore_index=-1,
loss_weight=1.0)):
if num_classes is not None:
warnings.warn(
'`num_classes` is deprecated now, please set '
'`num_stuff_classes` directly, the `num_classes` will be '
'set to `num_stuff_classes + 1`')
# num_classes = num_stuff_classes + 1 for PanopticFPN.
assert num_classes == num_stuff_classes + 1
super(PanopticFPNHead, self).__init__(num_stuff_classes + 1, init_cfg,
loss_seg)
self.num_things_classes = num_things_classes
self.num_stuff_classes = num_stuff_classes
if fg_range is not None and bg_range is not None:
self.fg_range = fg_range
self.bg_range = bg_range
self.num_things_classes = fg_range[1] - fg_range[0] + 1
self.num_stuff_classes = bg_range[1] - bg_range[0] + 1
warnings.warn(
'`fg_range` and `bg_range` are deprecated now, '
f'please use `num_things_classes`={self.num_things_classes} '
f'and `num_stuff_classes`={self.num_stuff_classes} instead.')
# Used feature layers are [start_level, end_level)
self.start_level = start_level
self.end_level = end_level
self.num_stages = end_level - start_level
self.inner_channels = inner_channels
self.conv_upsample_layers = ModuleList()
for i in range(start_level, end_level):
self.conv_upsample_layers.append(
ConvUpsample(
in_channels,
inner_channels,
num_layers=i if i > 0 else 1,
num_upsample=i if i > 0 else 0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
))
self.conv_logits = nn.Conv2d(inner_channels, self.num_classes, 1)
def _set_things_to_void(self, gt_semantic_seg):
"""Merge thing classes to one class.
In PanopticFPN, the background labels will be reset from `0` to
`self.num_stuff_classes-1`, the foreground labels will be merged to
`self.num_stuff_classes`-th channel.
"""
gt_semantic_seg = gt_semantic_seg.int()
fg_mask = gt_semantic_seg < self.num_things_classes
bg_mask = (gt_semantic_seg >= self.num_things_classes) * (
gt_semantic_seg < self.num_things_classes + self.num_stuff_classes)
new_gt_seg = torch.clone(gt_semantic_seg)
new_gt_seg = torch.where(bg_mask,
gt_semantic_seg - self.num_things_classes,
new_gt_seg)
new_gt_seg = torch.where(fg_mask,
fg_mask.int() * self.num_stuff_classes,
new_gt_seg)
return new_gt_seg
def loss(self, seg_preds, gt_semantic_seg):
"""The loss of PanopticFPN head.
Things classes will be merged to one class in PanopticFPN.
"""
gt_semantic_seg = self._set_things_to_void(gt_semantic_seg)
return super().loss(seg_preds, gt_semantic_seg)
def init_weights(self):
super().init_weights()
nn.init.normal_(self.conv_logits.weight.data, 0, 0.01)
self.conv_logits.bias.data.zero_()
def forward(self, x):
# the number of subnets must be not more than
# the length of features.
assert self.num_stages <= len(x)
feats = []
for i, layer in enumerate(self.conv_upsample_layers):
f = layer(x[self.start_level + i])
feats.append(f)
feats = torch.sum(torch.stack(feats, dim=0), dim=0)
seg_preds = self.conv_logits(feats)
out = dict(seg_preds=seg_preds, feats=feats)
return out