Add base one-stage segmentor (#5904)

Author: Shilong Zhang (committed via GitHub)
Parent: 9696414415
Commit: d53fbbc587
Files added:
1. mmdet/models/dense_heads/base_mask_head.py (+116 lines)
2. mmdet/models/detectors/single_stage_instance_seg.py (+363 lines)

mmdet/models/dense_heads/base_mask_head.py
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from mmcv.runner import BaseModule
class BaseMaskHead(BaseModule, metaclass=ABCMeta):
"""Base class for mask heads used in One-Stage Instance Segmentation."""
def __init__(self, init_cfg):
super(BaseMaskHead, self).__init__(init_cfg)
@abstractmethod
def loss(self, **kwargs):
pass
@abstractmethod
def get_results(self, **kwargs):
"""Get precessed :obj:`InstanceData` of multiple images."""
pass
def forward_train(self,
x,
gt_labels,
gt_masks,
img_metas,
gt_bboxes=None,
gt_bboxes_ignore=None,
positive_infos=None,
**kwargs):
"""
Args:
x (list[Tensor] | tuple[Tensor]): Features from FPN.
Each has a shape (B, C, H, W).
gt_labels (list[Tensor]): Ground truth labels of all images,
each has shape (num_gts,).
gt_masks (list[Tensor]): Masks for each bbox, each has shape
(num_gts, h, w).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
gt_bboxes (list[Tensor]): Ground truth bboxes of each image,
each item has shape (num_gts, 4).
gt_bboxes_ignore (list[Tensor], optional): Ground truth bboxes to be
ignored, each item has shape (num_ignored_gts, 4).
positive_infos (list[:obj:`InstanceData`], optional): Information
about the positive samples. Used when the label assignment is
done outside the MaskHead, e.g., in the BboxHead of
YOLACT or CondInst. When the label assignment is done inside
the MaskHead, as in SOLO, it is None. All values
in it should have shape (num_positive_samples, *).
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
if positive_infos is None:
outs = self(x)
else:
outs = self(x, positive_infos)
assert isinstance(outs, tuple), 'Forward results should be a tuple, ' \
'even if only one item is returned'
loss = self.loss(
*outs,
gt_labels=gt_labels,
gt_masks=gt_masks,
img_metas=img_metas,
gt_bboxes=gt_bboxes,
gt_bboxes_ignore=gt_bboxes_ignore,
positive_infos=positive_infos,
**kwargs)
return loss
def simple_test(self,
feats,
img_metas,
rescale=False,
instances_list=None,
**kwargs):
"""Test function without test-time augmentation.
Args:
feats (tuple[torch.Tensor]): Multi-level features from the
upstream network, each is a 4D-tensor.
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
instances_list (list[:obj:`InstanceData`], optional): Detection
results of each image after post-processing. Only exists
when there is a `bbox_head`, as in `YOLACT`, `CondInst`, etc.
Returns:
list[:obj:`InstanceData`]: Instance segmentation \
results of each image after post-processing. \
Each item usually contains the following keys. \
- scores (Tensor): Classification scores, has shape
(num_instances,)
- labels (Tensor): Has shape (num_instances,).
- masks (Tensor): Processed mask results, has
shape (num_instances, h, w).
"""
if instances_list is None:
outs = self(feats)
else:
outs = self(feats, instances_list=instances_list)
mask_inputs = outs + (img_metas, )
results_list = self.get_results(
*mask_inputs,
rescale=rescale,
instances_list=instances_list,
**kwargs)
return results_list
def onnx_export(self, img, img_metas):
raise NotImplementedError(f'{self.__class__.__name__} does '
f'not support ONNX export')
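
For orientation (not part of this commit), a concrete head would subclass `BaseMaskHead` roughly as sketched below; `ToyMaskHead`, its single conv layer, and its placeholder loss are purely illustrative and do not correspond to any head in mmdet. The sketch only shows the contract: `forward` returns a tuple, and `loss` / `get_results` must be overridden.

import torch.nn as nn


class ToyMaskHead(BaseMaskHead):
    """Hypothetical subclass, only to illustrate the abstract interface."""

    def __init__(self, in_channels=256, num_classes=80, init_cfg=None):
        super(ToyMaskHead, self).__init__(init_cfg)
        self.num_classes = num_classes
        # one mask logit map per class from each FPN level
        self.mask_conv = nn.Conv2d(in_channels, num_classes, 1)

    def forward(self, feats):
        # forward results must be a tuple, even if only one item is returned
        mask_preds = [self.mask_conv(feat) for feat in feats]
        return (mask_preds, )

    def loss(self, mask_preds, gt_labels, gt_masks, img_metas, **kwargs):
        # placeholder: a real head would assign targets here and return
        # the mask loss(es) as named tensors
        loss_mask = sum(pred.sigmoid().mean() for pred in mask_preds)
        return dict(loss_mask=loss_mask)

    def get_results(self, mask_preds, img_metas, rescale=False, **kwargs):
        # placeholder: a real head would turn predictions into one
        # :obj:`InstanceData` per image here
        raise NotImplementedError
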

mmdet/models/detectors/single_stage_instance_seg.py
@@ -0,0 +1,363 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
import mmcv
import numpy as np
import torch
from mmdet.core.visualization.image import imshow_det_bboxes
from ..builder import DETECTORS, build_backbone, build_head, build_neck
from .base import BaseDetector
INF = 1e8
@DETECTORS.register_module()
class SingleStageInstanceSegmentor(BaseDetector):
"""Base class for single-stage instance segmentors."""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
mask_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None,
init_cfg=None):
if pretrained:
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
backbone.pretrained = pretrained
super(SingleStageInstanceSegmentor, self).__init__(init_cfg=init_cfg)
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
else:
self.neck = None
if bbox_head is not None:
bbox_head.update(train_cfg=copy.deepcopy(train_cfg))
bbox_head.update(test_cfg=copy.deepcopy(test_cfg))
self.bbox_head = build_head(bbox_head)
else:
self.bbox_head = None
assert mask_head, f'`mask_head` must ' \
f'be implemented in {self.__class__.__name__}'
mask_head.update(train_cfg=copy.deepcopy(train_cfg))
mask_head.update(test_cfg=copy.deepcopy(test_cfg))
self.mask_head = build_head(mask_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
def extract_feat(self, img):
"""Directly extract features from the backbone and neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_dummy(self, img):
"""Used for computing network flops.
See `mmdetection/tools/analysis_tools/get_flops.py`
"""
raise NotImplementedError(
f'`forward_dummy` is not implemented in {self.__class__.__name__}')
def forward_train(self,
img,
img_metas,
gt_masks,
gt_labels,
gt_bboxes=None,
gt_bboxes_ignore=None,
**kwargs):
"""
Args:
img (Tensor): Input images of shape (B, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A list of image info dicts where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
gt_masks (list[:obj:`BitmapMasks`] | None): The segmentation
masks for each box.
gt_labels (list[Tensor]): Class indices corresponding to each box.
gt_bboxes (list[Tensor]): Each item is the truth boxes
of each image in [tl_x, tl_y, br_x, br_y] format.
Default: None.
gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
gt_masks = [
gt_mask.to_tensor(dtype=torch.bool, device=img.device)
for gt_mask in gt_masks
]
x = self.extract_feat(img)
losses = dict()
# CondInst and YOLACT have bbox_head
if self.bbox_head:
# bbox_head_preds is a tuple
bbox_head_preds = self.bbox_head(x)
# positive_infos is a list of obj:`InstanceData`
# It contains the information about the positive samples
# CondInst, YOLACT
det_losses, positive_infos = self.bbox_head.loss(
*bbox_head_preds,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
gt_masks=gt_masks,
img_metas=img_metas,
gt_bboxes_ignore=gt_bboxes_ignore,
**kwargs)
losses.update(det_losses)
else:
positive_infos = None
mask_loss = self.mask_head.forward_train(
x,
gt_labels,
gt_masks,
img_metas,
positive_infos=positive_infos,
gt_bboxes=gt_bboxes,
gt_bboxes_ignore=gt_bboxes_ignore,
**kwargs)
# avoid loss override
assert not set(mask_loss.keys()) & set(losses.keys())
losses.update(mask_loss)
return losses
def simple_test(self, img, img_metas, rescale=False):
"""Test function without test-time augmentation.
Args:
img (torch.Tensor): Images with shape (B, C, H, W).
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[tuple]: Formatted bbox and mask results of multiple \
images. The outer list corresponds to each image. \
Each tuple contains two types of results of a single image:
- bbox_results (list[np.ndarray]): BBox results of a
single image. The list corresponds to each class;
each ndarray has shape (N, 5), where N is the number of
bboxes of this category, and the last dimension
5 is arranged as (x1, y1, x2, y2, score).
- mask_results (list[np.ndarray]): Mask results of a
single image. The list corresponds to each class;
each ndarray has shape (N, img_h, img_w), where N
is the number of masks of this category.
"""
feat = self.extract_feat(img)
if self.bbox_head:
outs = self.bbox_head(feat)
# results_list is list[obj:`InstanceData`]
results_list = self.bbox_head.get_results(
*outs, img_metas=img_metas, cfg=self.test_cfg, rescale=rescale)
else:
results_list = None
results_list = self.mask_head.simple_test(
feat, img_metas, rescale=rescale, instances_list=results_list)
format_results_list = []
for results in results_list:
format_results_list.append(self.format_results(results))
return format_results_list
def format_results(self, results):
"""Format the model predictions according to the interface with
dataset.
Args:
results (:obj:`InstanceData`): Processed
results of a single image. Usually contains the
following keys.
- scores (Tensor): Classification scores, has shape
(num_instances,)
- labels (Tensor): Has shape (num_instances,).
- masks (Tensor): Processed mask results, has
shape (num_instances, h, w).
Returns:
tuple: Formatted bbox and mask results. It contains two items:
- bbox_results (list[np.ndarray]): BBox results of a
single image. The list corresponds to each class;
each ndarray has shape (N, 5), where N is the number of
bboxes of this category, and the last dimension
5 is arranged as (x1, y1, x2, y2, score).
- mask_results (list[np.ndarray]): Mask results of a
single image. The list corresponds to each class;
each ndarray has shape (N, img_h, img_w), where N
is the number of masks of this category.
"""
data_keys = results.keys()
assert 'scores' in data_keys
assert 'labels' in data_keys
assert 'masks' in data_keys, \
'results should contain ' \
'masks when formatting the results'
mask_results = [[] for _ in range(self.mask_head.num_classes)]
num_masks = len(results)
if num_masks == 0:
bbox_results = [
np.zeros((0, 5), dtype=np.float32)
for _ in range(self.mask_head.num_classes)
]
return bbox_results, mask_results
labels = results.labels.detach().cpu().numpy()
if 'bboxes' not in results:
# create dummy bbox results to store the scores
results.bboxes = results.scores.new_zeros(len(results), 4)
det_bboxes = torch.cat([results.bboxes, results.scores[:, None]],
dim=-1)
det_bboxes = det_bboxes.detach().cpu().numpy()
bbox_results = [
det_bboxes[labels == i, :]
for i in range(self.mask_head.num_classes)
]
masks = results.masks.detach().cpu().numpy()
for idx in range(num_masks):
mask = masks[idx]
mask_results[labels[idx]].append(mask)
return bbox_results, mask_results
def aug_test(self, imgs, img_metas, rescale=False):
raise NotImplementedError
def show_result(self,
img,
result,
score_thr=0.3,
bbox_color=(72, 101, 241),
text_color=(72, 101, 241),
mask_color=None,
thickness=2,
font_size=13,
win_name='',
show=False,
wait_time=0,
out_file=None):
"""Draw `result` over `img`.
Args:
img (str or Tensor): The image to be displayed.
result (tuple): Formatted bbox and mask results.
It contains two items:
- bbox_results (list[np.ndarray]): BBox results of a
single image. The list corresponds to each class;
each ndarray has shape (N, 5), where N is the number of
bboxes of this category, and the last dimension
5 is arranged as (x1, y1, x2, y2, score).
- mask_results (list[np.ndarray]): Mask results of a
single image. The list corresponds to each class;
each ndarray has shape (N, img_h, img_w), where N
is the number of masks of this category.
score_thr (float, optional): Minimum score of bboxes to be shown.
Default: 0.3.
bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox lines.
The tuple of color should be in BGR order. Default: (72, 101, 241).
text_color (str or tuple(int) or :obj:`Color`): Color of texts.
The tuple of color should be in BGR order. Default: (72, 101, 241).
mask_color (None or str or tuple(int) or :obj:`Color`):
Color of masks. The tuple of color should be in BGR order.
Default: None
thickness (int): Thickness of lines. Default: 2
font_size (int): Font size of texts. Default: 13
win_name (str): The window name. Default: ''
wait_time (float): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
out_file (str or None): The filename to write the image.
Default: None.
Returns:
img (ndarray): The image with drawn results. Only returned
when `show` is False and `out_file` is not specified.
"""
assert isinstance(result, tuple)
bbox_result, mask_result = result
bboxes = np.vstack(bbox_result)
img = mmcv.imread(img)
img = img.copy()
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(bbox_result)
]
labels = np.concatenate(labels)
if len(labels) == 0:
bboxes = np.zeros([0, 5])
masks = np.zeros([0, 0, 0])
# draw segmentation masks
else:
masks = mmcv.concat_list(mask_result)
if isinstance(masks[0], torch.Tensor):
masks = torch.stack(masks, dim=0).detach().cpu().numpy()
else:
masks = np.stack(masks, axis=0)
# bboxes are dummies (all zeros): derive tight boxes from the masks
if bboxes[:, :4].sum() == 0:
num_masks = len(bboxes)
x_any = masks.any(axis=1)
y_any = masks.any(axis=2)
for idx in range(num_masks):
x = np.where(x_any[idx, :])[0]
y = np.where(y_any[idx, :])[0]
if len(x) > 0 and len(y) > 0:
bboxes[idx, :4] = np.array(
[x[0], y[0], x[-1] + 1, y[-1] + 1],
dtype=np.float32)
# if out_file specified, do not show image in window
if out_file is not None:
show = False
# draw bounding boxes
img = imshow_det_bboxes(
img,
bboxes,
labels,
masks,
class_names=self.CLASSES,
score_thr=score_thr,
bbox_color=bbox_color,
text_color=text_color,
mask_color=mask_color,
thickness=thickness,
font_size=font_size,
win_name=win_name,
show=show,
wait_time=wait_time,
out_file=out_file)
if not (show or out_file):
return img
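
For context, here is a rough sketch of how a detector built on this class could be assembled from a config. It is not part of the commit: the backbone/neck entries follow the usual mmdet ResNet-50 + FPN settings, while the `mask_head` entry is deliberately left as a placeholder that must be replaced by a complete, registered head config (e.g. SOLO's) before the model can actually be built.

from mmdet.models import build_detector

# Schematic config: backbone/neck are the standard mmdet ResNet-50 + FPN
# settings; the mask_head entry is a placeholder and must be replaced by a
# complete, registered mask head config before the model can be built.
model_cfg = dict(
    type='SingleStageInstanceSegmentor',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='FPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=0,
        num_outs=5),
    bbox_head=None,  # SOLO-style heads do the label assignment in mask_head
    mask_head=dict(type='...'),  # placeholder: fill in a concrete mask head
    train_cfg=None,
    test_cfg=None)

# detector = build_detector(model_cfg)
# losses = detector.forward_train(img, img_metas, gt_masks, gt_labels)
# results = detector.simple_test(img, img_metas, rescale=True)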