Add base one-stage segmentor (#5904)
parent
9696414415
commit
d53fbbc587
2 changed files with 479 additions and 0 deletions
@ -0,0 +1,116 @@ |
||||
# Copyright (c) OpenMMLab. All rights reserved. |
||||
from abc import ABCMeta, abstractmethod |
||||
|
||||
from mmcv.runner import BaseModule |
||||
|
||||
|
||||
class BaseMaskHead(BaseModule, metaclass=ABCMeta):
    """Base class for mask heads used in One-Stage Instance Segmentation."""

    def __init__(self, init_cfg):
        super(BaseMaskHead, self).__init__(init_cfg)

    @abstractmethod
    def loss(self, **kwargs):
        """Compute the mask losses. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def get_results(self, **kwargs):
        """Get processed :obj:`InstanceData` of multiple images."""
        pass

    def forward_train(self,
                      x,
                      gt_labels,
                      gt_masks,
                      img_metas,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None,
                      positive_infos=None,
                      **kwargs):
        """Run a forward pass and compute the mask losses.

        Args:
            x (list[Tensor] | tuple[Tensor]): Features from FPN,
                each with shape (B, C, H, W).
            gt_labels (list[Tensor]): Ground truth labels of all images,
                each with shape (num_gts,).
            gt_masks (list[Tensor]): Masks for each bbox, with shape
                (num_gts, h, w).
            img_metas (list[dict]): Meta information of each image, e.g.,
                image size, scaling factor, etc.
            gt_bboxes (list[Tensor]): Ground truth bboxes of the image,
                each item with shape (num_gts, 4).
            gt_bboxes_ignore (list[Tensor], None): Ground truth bboxes to be
                ignored, each item with shape (num_ignored_gts, 4).
            positive_infos (list[:obj:`InstanceData`], optional): Information
                of positive samples. Provided when the label assignment is
                done outside the MaskHead, e.g., in the BboxHead of YOLACT
                or CondInst; None when the assignment is done inside the
                MaskHead (e.g., SOLO). All values in it should have shape
                (num_positive_samples, *).

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        # With external label assignment the positive sample info is fed
        # through the forward pass as well.
        preds = self(x) if positive_infos is None else self(x, positive_infos)

        assert isinstance(preds, tuple), 'Forward results should be a tuple, ' \
                                         'even if only one item is returned'
        return self.loss(
            *preds,
            gt_labels=gt_labels,
            gt_masks=gt_masks,
            img_metas=img_metas,
            gt_bboxes=gt_bboxes,
            gt_bboxes_ignore=gt_bboxes_ignore,
            positive_infos=positive_infos,
            **kwargs)

    def simple_test(self,
                    feats,
                    img_metas,
                    rescale=False,
                    instances_list=None,
                    **kwargs):
        """Test function without test-time augmentation.

        Args:
            feats (tuple[torch.Tensor]): Multi-level features from the
                upstream network, each is a 4D-tensor.
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.
            instances_list (list[obj:`InstanceData`], optional): Detection
                results of each image after the post process. Only exists
                if there is a `bbox_head`, like `YOLACT`, `CondInst`, etc.

        Returns:
            list[obj:`InstanceData`]: Instance segmentation \
                results of each image after the post process. \
                Each item usually contains the following keys. \

                - scores (Tensor): Classification scores, has a shape
                  (num_instance,)
                - labels (Tensor): Has a shape (num_instances,).
                - masks (Tensor): Processed mask results, has a
                  shape (num_instances, h, w).
        """
        if instances_list is None:
            preds = self(feats)
        else:
            preds = self(feats, instances_list=instances_list)

        # img_metas rides along as the last positional argument.
        return self.get_results(
            *preds,
            img_metas,
            rescale=rescale,
            instances_list=instances_list,
            **kwargs)

    def onnx_export(self, img, img_metas):
        """ONNX export is not supported for this head."""
        raise NotImplementedError(f'{self.__class__.__name__} does '
                                  f'not support ONNX EXPORT')
@ -0,0 +1,363 @@ |
||||
# Copyright (c) OpenMMLab. All rights reserved. |
||||
import copy |
||||
import warnings |
||||
|
||||
import mmcv |
||||
import numpy as np |
||||
import torch |
||||
|
||||
from mmdet.core.visualization.image import imshow_det_bboxes |
||||
from ..builder import DETECTORS, build_backbone, build_head, build_neck |
||||
from .base import BaseDetector |
||||
|
||||
# NOTE(review): unused in the visible code — presumably a large-distance
# sentinel for subclasses/heads; confirm before removing.
INF = 1e8
||||
|
||||
|
||||
@DETECTORS.register_module()
class SingleStageInstanceSegmentor(BaseDetector):
    """Base class for single-stage instance segmentors.

    A single-stage instance segmentor is composed of a backbone, an
    optional neck, an optional ``bbox_head`` (e.g. in YOLACT or CondInst)
    and a mandatory ``mask_head`` that predicts instance masks.
    """

    def __init__(self,
                 backbone,
                 neck=None,
                 bbox_head=None,
                 mask_head=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 init_cfg=None):

        if pretrained:
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            backbone.pretrained = pretrained
        super(SingleStageInstanceSegmentor, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone)
        if neck is not None:
            self.neck = build_neck(neck)
        else:
            self.neck = None
        if bbox_head is not None:
            # Deep copies keep the shared train/test cfg from being mutated
            # by one head and silently affecting the other.
            bbox_head.update(train_cfg=copy.deepcopy(train_cfg))
            bbox_head.update(test_cfg=copy.deepcopy(test_cfg))
            self.bbox_head = build_head(bbox_head)
        else:
            self.bbox_head = None

        assert mask_head, f'`mask_head` must ' \
                          f'be implemented in {self.__class__.__name__}'
        mask_head.update(train_cfg=copy.deepcopy(train_cfg))
        mask_head.update(test_cfg=copy.deepcopy(test_cfg))
        self.mask_head = build_head(mask_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

    def extract_feat(self, img):
        """Directly extract features from the backbone and neck."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def forward_dummy(self, img):
        """Used for computing network flops.

        See `mmdetection/tools/analysis_tools/get_flops.py`
        """
        raise NotImplementedError(
            f'`forward_dummy` is not implemented in {self.__class__.__name__}')

    def forward_train(self,
                      img,
                      img_metas,
                      gt_masks,
                      gt_labels,
                      gt_bboxes=None,
                      gt_bboxes_ignore=None,
                      **kwargs):
        """
        Args:
            img (Tensor): Input images of shape (B, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            gt_masks (list[:obj:`BitmapMasks`] | None) : The segmentation
                masks for each box.
            gt_labels (list[Tensor]): Class indices corresponding to each box
            gt_bboxes (list[Tensor]): Each item is the truth boxes
                of each image in [tl_x, tl_y, br_x, br_y] format.
                Default: None.
            gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
                boxes can be ignored when computing the loss.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """

        # Convert the mask containers to boolean tensors on the same
        # device as the input images.
        gt_masks = [
            gt_mask.to_tensor(dtype=torch.bool, device=img.device)
            for gt_mask in gt_masks
        ]
        x = self.extract_feat(img)
        losses = dict()

        # CondInst and YOLACT have bbox_head
        if self.bbox_head:
            # bbox_head_preds is a tuple
            bbox_head_preds = self.bbox_head(x)
            # positive_infos is a list of obj:`InstanceData`
            # It contains the information about the positive samples
            # CondInst, YOLACT
            det_losses, positive_infos = self.bbox_head.loss(
                *bbox_head_preds,
                gt_bboxes=gt_bboxes,
                gt_labels=gt_labels,
                gt_masks=gt_masks,
                img_metas=img_metas,
                gt_bboxes_ignore=gt_bboxes_ignore,
                **kwargs)
            losses.update(det_losses)
        else:
            positive_infos = None

        mask_loss = self.mask_head.forward_train(
            x,
            gt_labels,
            gt_masks,
            img_metas,
            positive_infos=positive_infos,
            gt_bboxes=gt_bboxes,
            gt_bboxes_ignore=gt_bboxes_ignore,
            **kwargs)
        # avoid loss override
        assert not set(mask_loss.keys()) & set(losses.keys())

        losses.update(mask_loss)
        return losses

    def simple_test(self, img, img_metas, rescale=False):
        """Test function without test-time augmentation.

        Args:
            img (torch.Tensor): Images with shape (B, C, H, W).
            img_metas (list[dict]): List of image information.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to False.

        Returns:
            list(tuple): Formatted bbox and mask results of multiple \
                images. The outer list corresponds to each image. \
                Each tuple contains two type of results of single image:

                - bbox_results (list[np.ndarray]): BBox results of
                  single image. The list corresponds to each class.
                  each ndarray has a shape (N, 5), N is the number of
                  bboxes with this category, and last dimension
                  5 arrange as (x1, y1, x2, y2, scores).
                - mask_results (list[np.ndarray]): Mask results of
                  single image. The list corresponds to each class.
                  each ndarray has a shape (N, img_h, img_w), N
                  is the number of masks with this category.
        """
        feat = self.extract_feat(img)
        if self.bbox_head:
            outs = self.bbox_head(feat)
            # results_list is list[obj:`InstanceData`]
            results_list = self.bbox_head.get_results(
                *outs, img_metas=img_metas, cfg=self.test_cfg, rescale=rescale)
        else:
            results_list = None

        results_list = self.mask_head.simple_test(
            feat, img_metas, rescale=rescale, instances_list=results_list)

        # Convert each per-image InstanceData into the dataset interface
        # format (per-class bbox and mask arrays).
        return [self.format_results(results) for results in results_list]

    def format_results(self, results):
        """Format the model predictions according to the interface with
        dataset.

        Args:
            results (:obj:`InstanceData`): Processed
                results of single images. Usually contains
                following keys.

                - scores (Tensor): Classification scores, has shape
                  (num_instance,)
                - labels (Tensor): Has shape (num_instances,).
                - masks (Tensor): Processed mask results, has
                  shape (num_instances, h, w).

        Returns:
            tuple: Formatted bbox and mask results. It contains two items:

                - bbox_results (list[np.ndarray]): BBox results of
                  single image. The list corresponds to each class.
                  each ndarray has a shape (N, 5), N is the number of
                  bboxes with this category, and last dimension
                  5 arrange as (x1, y1, x2, y2, scores).
                - mask_results (list[np.ndarray]): Mask results of
                  single image. The list corresponds to each class.
                  each ndarray has shape (N, img_h, img_w), N
                  is the number of masks with this category.
        """
        data_keys = results.keys()
        assert 'scores' in data_keys
        assert 'labels' in data_keys

        assert 'masks' in data_keys, \
            'results should contain ' \
            'masks when format the results '
        mask_results = [[] for _ in range(self.mask_head.num_classes)]

        num_masks = len(results)

        if num_masks == 0:
            # No detections: return empty per-class containers.
            bbox_results = [
                np.zeros((0, 5), dtype=np.float32)
                for _ in range(self.mask_head.num_classes)
            ]
            return bbox_results, mask_results

        labels = results.labels.detach().cpu().numpy()

        if 'bboxes' not in results:
            # create dummy bbox results to store the scores
            results.bboxes = results.scores.new_zeros(len(results), 4)

        det_bboxes = torch.cat([results.bboxes, results.scores[:, None]],
                               dim=-1)
        det_bboxes = det_bboxes.detach().cpu().numpy()
        bbox_results = [
            det_bboxes[labels == i, :]
            for i in range(self.mask_head.num_classes)
        ]

        masks = results.masks.detach().cpu().numpy()

        # Bucket each mask under its predicted class.
        for idx in range(num_masks):
            mask = masks[idx]
            mask_results[labels[idx]].append(mask)

        return bbox_results, mask_results

    def aug_test(self, imgs, img_metas, rescale=False):
        """Test-time augmentation is not supported for this segmentor."""
        raise NotImplementedError

    def show_result(self,
                    img,
                    result,
                    score_thr=0.3,
                    bbox_color=(72, 101, 241),
                    text_color=(72, 101, 241),
                    mask_color=None,
                    thickness=2,
                    font_size=13,
                    win_name='',
                    show=False,
                    wait_time=0,
                    out_file=None):
        """Draw `result` over `img`.

        Args:
            img (str or Tensor): The image to be displayed.
            result (tuple): Format bbox and mask results.
                It contains two items:

                - bbox_results (list[np.ndarray]): BBox results of
                  single image. The list corresponds to each class.
                  each ndarray has a shape (N, 5), N is the number of
                  bboxes with this category, and last dimension
                  5 arrange as (x1, y1, x2, y2, scores).
                - mask_results (list[np.ndarray]): Mask results of
                  single image. The list corresponds to each class.
                  each ndarray has shape (N, img_h, img_w), N
                  is the number of masks with this category.

            score_thr (float, optional): Minimum score of bboxes to be shown.
                Default: 0.3.
            bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
                The tuple of color should be in BGR order. Default: 'green'
            text_color (str or tuple(int) or :obj:`Color`):Color of texts.
                The tuple of color should be in BGR order. Default: 'green'
            mask_color (None or str or tuple(int) or :obj:`Color`):
                Color of masks. The tuple of color should be in BGR order.
                Default: None
            thickness (int): Thickness of lines. Default: 2
            font_size (int): Font size of texts. Default: 13
            win_name (str): The window name. Default: ''
            wait_time (float): Value of waitKey param.
                Default: 0.
            show (bool): Whether to show the image.
                Default: False.
            out_file (str or None): The filename to write the image.
                Default: None.

        Returns:
            img (Tensor): Only if not `show` or `out_file`
        """

        assert isinstance(result, tuple)
        bbox_result, mask_result = result
        bboxes = np.vstack(bbox_result)
        img = mmcv.imread(img)
        img = img.copy()
        labels = [
            np.full(bbox.shape[0], i, dtype=np.int32)
            for i, bbox in enumerate(bbox_result)
        ]
        labels = np.concatenate(labels)
        if len(labels) == 0:
            bboxes = np.zeros([0, 5])
            masks = np.zeros([0, 0, 0])
        # draw segmentation masks
        else:
            masks = mmcv.concat_list(mask_result)

            if isinstance(masks[0], torch.Tensor):
                masks = torch.stack(masks, dim=0).detach().cpu().numpy()
            else:
                masks = np.stack(masks, axis=0)
            # dummy bboxes: all-zero boxes come from `format_results` when
            # the model predicts no boxes, so derive boxes from mask extents
            if bboxes[:, :4].sum() == 0:
                num_masks = len(bboxes)
                x_any = masks.any(axis=1)
                y_any = masks.any(axis=2)
                for idx in range(num_masks):
                    x = np.where(x_any[idx, :])[0]
                    y = np.where(y_any[idx, :])[0]
                    if len(x) > 0 and len(y) > 0:
                        bboxes[idx, :4] = np.array(
                            [x[0], y[0], x[-1] + 1, y[-1] + 1],
                            dtype=np.float32)
        # if out_file specified, do not show image in window
        if out_file is not None:
            show = False
        # draw bounding boxes
        img = imshow_det_bboxes(
            img,
            bboxes,
            labels,
            masks,
            class_names=self.CLASSES,
            score_thr=score_thr,
            bbox_color=bbox_color,
            text_color=text_color,
            mask_color=mask_color,
            thickness=thickness,
            font_size=font_size,
            win_name=win_name,
            show=show,
            wait_time=wait_time,
            out_file=out_file)

        if not (show or out_file):
            return img
Loading…
Reference in new issue