OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io/
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
1.2 KiB
34 lines
1.2 KiB
# Copyright (c) OpenMMLab. All rights reserved. |
|
from ..builder import HEADS |
|
from .standard_roi_head import StandardRoIHead |
|
|
|
|
|
@HEADS.register_module()
class DoubleHeadRoIHead(StandardRoIHead):
    """RoI head for Double Head R-CNN.

    Paper: https://arxiv.org/abs/1904.06493

    Compared with ``StandardRoIHead``, this head overrides ``_bbox_forward``
    so that RoI features are pooled twice:

    - once with the extractor's default settings, for the classification
      branch;
    - once with ``roi_scale_factor=reg_roi_scale_factor``, for the
      regression branch.

    Both feature sets are handed to ``bbox_head``.

    Args:
        reg_roi_scale_factor: Forwarded as ``roi_scale_factor`` to
            ``bbox_roi_extractor`` when pooling regression features.
            NOTE(review): presumably a float that rescales each RoI before
            pooling; confirm against the RoI extractor implementation.
        **kwargs: Passed through unchanged to ``StandardRoIHead``.
    """

    def __init__(self, reg_roi_scale_factor, **kwargs):
        super().__init__(**kwargs)
        # Read later by ``_bbox_forward`` for the regression-branch pooling.
        self.reg_roi_scale_factor = reg_roi_scale_factor

    def _bbox_forward(self, x, rois):
        """Run the box head on ``rois``; shared by training and testing.

        Args:
            x: Sequence of feature maps (presumably one per pyramid level);
                only the first ``bbox_roi_extractor.num_inputs`` entries are
                passed to the RoI extractor.
            rois: RoIs to pool features for. The expected layout is defined
                by the RoI extractor (TODO confirm).

        Returns:
            dict: ``cls_score`` and ``bbox_pred`` as returned by
            ``bbox_head``, plus ``bbox_feats`` holding the
            classification-branch RoI features.
        """
        extractor = self.bbox_roi_extractor
        n_levels = extractor.num_inputs

        # Classification branch: no roi_scale_factor, so the extractor's
        # default is used.
        cls_feats = extractor(x[:n_levels], rois)
        # Regression branch: same RoIs, pooled with the configured factor.
        reg_feats = extractor(
            x[:n_levels], rois, roi_scale_factor=self.reg_roi_scale_factor)

        if self.with_shared_head:
            cls_feats = self.shared_head(cls_feats)
            reg_feats = self.shared_head(reg_feats)

        # bbox_head consumes both feature sets and yields (scores, deltas).
        cls_score, bbox_pred = self.bbox_head(cls_feats, reg_feats)
        return {
            'cls_score': cls_score,
            'bbox_pred': bbox_pred,
            'bbox_feats': cls_feats,
        }
|
|
|