OpenMMLab Detection Toolbox and Benchmark https://mmdetection.readthedocs.io/
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

53 lines
1.9 KiB

import logging
import sys

from mmdet.core import merge_aug_proposals

logger = logging.getLogger(__name__)

# ``completed`` is only imported on Python 3.7+; the async test method that
# uses it is defined under the same version check.
if sys.version_info >= (3, 7):
    from mmdet.utils.contextmanagers import completed
class RPNTestMixin(object):
    """Test-time RPN helpers for a detector that owns ``self.rpn_head``.

    The methods rely only on ``self.rpn_head`` being callable on features,
    providing ``get_bboxes`` and carrying a ``test_cfg`` mapping.
    """

    if sys.version_info >= (3, 7):

        async def async_test_rpn(self, x, img_metas):
            """Run the RPN head inside the ``completed`` async context.

            Args:
                x: Features passed unchanged to ``self.rpn_head``.
                img_metas: Image meta information forwarded to
                    ``self.rpn_head.get_bboxes``.

            Returns:
                The proposal list returned by ``self.rpn_head.get_bboxes``.
            """
            # Read the interval with ``get`` rather than ``pop``: popping
            # removed a user-configured value from the shared test config on
            # the first call, so every later call silently used the default.
            sleep_interval = self.rpn_head.test_cfg.get(
                'async_sleep_interval', 0.025)
            async with completed(
                    __name__, 'rpn_head_forward',
                    sleep_interval=sleep_interval):
                rpn_outs = self.rpn_head(x)

            proposal_list = self.rpn_head.get_bboxes(*rpn_outs, img_metas)
            return proposal_list

    def simple_test_rpn(self, x, img_metas):
        """Run the RPN head on ``x`` and decode its outputs into proposals.

        Args:
            x: Features passed unchanged to ``self.rpn_head``.
            img_metas: Image meta information forwarded to
                ``self.rpn_head.get_bboxes`` after the unpacked head outputs.

        Returns:
            The proposal list returned by ``self.rpn_head.get_bboxes``.
        """
        rpn_outs = self.rpn_head(x)
        return self.rpn_head.get_bboxes(*rpn_outs, img_metas)

    def aug_test_rpn(self, feats, img_metas):
        """Collect proposals for every augmentation and merge them per image.

        Args:
            feats: One feature entry per augmentation, paired element-wise
                with ``img_metas``.
            img_metas: One sequence of image metas per augmentation; each
                sequence is indexed by image, and the length of the first
                sequence is taken as the number of images.

        Returns:
            list: One merged proposal result per image, as produced by
            ``merge_aug_proposals`` with ``self.rpn_head.test_cfg``.
        """
        samples_per_gpu = len(img_metas[0])

        # aug_proposals[i] gathers the proposals of image ``i`` across all
        # augmentations.
        aug_proposals = [[] for _ in range(samples_per_gpu)]
        for x, img_meta in zip(feats, img_metas):
            proposal_list = self.simple_test_rpn(x, img_meta)
            for i, proposals in enumerate(proposal_list):
                aug_proposals[i].append(proposals)

        # Reorganize 'img_metas' from [augmentation][image] to
        # [image][augmentation] so it matches the layout of 'aug_proposals'.
        aug_img_metas = [
            [metas[i] for metas in img_metas]
            for i in range(samples_per_gpu)
        ]

        # after merging, proposals will be rescaled to the original image size
        return [
            merge_aug_proposals(proposals, aug_img_meta,
                                self.rpn_head.test_cfg)
            for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
        ]