OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
2.8 KiB
86 lines
2.8 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
from abc import ABCMeta, abstractmethod |
|
|
|
import torch.nn.functional as F |
|
from mmcv.runner import BaseModule, force_fp32 |
|
|
|
from ..builder import build_loss |
|
from ..utils import interpolate_as |
|
|
|
|
|
class BaseSemanticHead(BaseModule, metaclass=ABCMeta):
    """Base module of Semantic Head.

    Subclasses implement :meth:`forward`; this base class supplies the
    loss computation, the training entry point and a simple test-time
    post-processing routine built on top of it.

    Args:
        num_classes (int): the number of classes.
        init_cfg (dict): the initialization config.
        loss_seg (dict): the loss of the semantic head.
    """

    def __init__(self,
                 num_classes,
                 init_cfg=None,
                 loss_seg=dict(
                     type='CrossEntropyLoss',
                     ignore_index=255,
                     loss_weight=1.0)):
        """Build the segmentation loss and record the number of classes.

        Args:
            num_classes (int): the number of classes; used by
                :meth:`loss` as the size of the last dimension when the
                logits are flattened.
            init_cfg (dict): the initialization config, forwarded to the
                parent constructor.
            loss_seg (dict): the loss config, passed to ``build_loss``.
        """
        super().__init__(init_cfg)
        self.loss_seg = build_loss(loss_seg)
        self.num_classes = num_classes

    @force_fp32(apply_to=('seg_preds', ))
    def loss(self, seg_preds, gt_semantic_seg):
        """Get the loss of semantic head.

        Args:
            seg_preds (Tensor): The input logits with the shape (N, C, H, W).
            gt_semantic_seg: The ground truth of semantic segmentation with
                the shape (N, H, W).

        Returns:
            dict: the loss of semantic head, stored under the key
            ``'loss_seg'``.
        """
        # Only the two trailing (spatial) dimensions are compared; on a
        # mismatch the logits are passed through `interpolate_as` with the
        # ground truth as the reference.
        # NOTE(review): presumably this resizes the logits to the ground
        # truth's spatial size -- confirm against `..utils.interpolate_as`.
        if seg_preds.shape[-2:] != gt_semantic_seg.shape[-2:]:
            seg_preds = interpolate_as(seg_preds, gt_semantic_seg)

        # Move the class dimension last so that every row of the flattened
        # tensor holds the logits of a single pixel.
        seg_preds = seg_preds.permute((0, 2, 3, 1))

        loss_seg = self.loss_seg(
            seg_preds.reshape(-1, self.num_classes),  # => [NxHxW, C]
            gt_semantic_seg.reshape(-1).long())
        return dict(loss_seg=loss_seg)

    @abstractmethod
    def forward(self, x):
        """Placeholder of forward function.

        Returns:
            dict[str, Tensor]: A dictionary, including features
                and predicted scores. Required keys: 'seg_preds'
                and 'feats'.
        """

    def forward_train(self, x, gt_semantic_seg):
        """Run :meth:`forward` and compute the training loss.

        Args:
            x: the input passed unchanged to :meth:`forward`.
            gt_semantic_seg: the ground truth of semantic segmentation,
                passed unchanged to :meth:`loss`.

        Returns:
            dict: the result of :meth:`loss` on the ``'seg_preds'`` entry
            of the forward output.
        """
        output = self.forward(x)
        return self.loss(output['seg_preds'], gt_semantic_seg)

    def simple_test(self, x, img_metas, rescale=False):
        """Run :meth:`forward` and resize the predictions for testing.

        The ``'seg_preds'`` output is first bilinearly resized to the
        ``'pad_shape'`` of the first entry of ``img_metas``. When
        ``rescale`` is true it is then cropped to that entry's
        ``'img_shape'`` and bilinearly resized to its ``'ori_shape'``.

        Only ``img_metas[0]`` is read, so every sample in the batch is
        resized with the meta information of the first image.

        Args:
            x: the input passed unchanged to :meth:`forward`.
            img_metas (list[dict]): image meta information; the first
                entry must provide ``'pad_shape'`` and, when ``rescale``
                is true, ``'img_shape'`` and ``'ori_shape'``.
            rescale (bool): whether to crop and resize the predictions to
                the original image size. Default: False.

        Returns:
            Tensor: the resized segmentation predictions.
        """
        meta = img_metas[0]
        seg_preds = self.forward(x)['seg_preds']
        seg_preds = F.interpolate(
            seg_preds,
            size=meta['pad_shape'][:2],
            mode='bilinear',
            align_corners=False)

        if rescale:
            # Crop to the `img_shape` extent first, then resize the crop
            # to the original image size.
            h, w, _ = meta['img_shape']
            seg_preds = seg_preds[:, :, :h, :w]

            h, w, _ = meta['ori_shape']
            seg_preds = F.interpolate(
                seg_preds, size=(h, w), mode='bilinear', align_corners=False)
        return seg_preds
|
|
|