Add doc string for datasets (#3130)

* add docstring for pipelines

* add docs string for transform pipeline

* add dataset docs

* resolve comments

* refactor table

* resolve comments

* delete __repr__

* resolve comments

* minor update

* update readme

* minor update

* rename api
pull/3134/head
Jerry Jiarui XU 5 years ago committed by GitHub
parent 35ec6d13f4
commit 792273be65
  README.md | 64
  docs/api.rst | 2
  docs/model_zoo.md | 3
  mmdet/datasets/cityscapes.py | 39
  mmdet/datasets/coco.py | 53
  mmdet/datasets/custom.py | 66
  mmdet/datasets/dataset_wrappers.py | 32
  mmdet/datasets/lvis.py | 19
  mmdet/datasets/pipelines/compose.py | 15
  mmdet/datasets/pipelines/formating.py | 113
  mmdet/datasets/pipelines/loading.py | 103
  mmdet/datasets/pipelines/test_time_aug.py | 38
  mmdet/datasets/pipelines/transforms.py | 252
  mmdet/datasets/voc.py | 22
  mmdet/datasets/wider_face.py | 9
  mmdet/datasets/xml_style.py | 37
  mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py | 4

@ -46,39 +46,41 @@ A comparison between v1.x and v2.0 codebases can be found in [compatibility.md](
## Benchmark and model zoo ## Benchmark and model zoo
Supported methods and backbones are shown in the below table.
Results and models are available in the [model zoo](docs/model_zoo.md). Results and models are available in the [model zoo](docs/model_zoo.md).
| | ResNet | ResNeXt | SENet | VGG | HRNet | RegNetX | Res2Net | Supported backbones:
|--------------------|:--------:|:--------:|:--------:|:--------:|:-----:|:--------:|:-----:| - [x] ResNet
| RPN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] ResNeXt
| Fast R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] VGG
| Faster R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ✓ | ✓ | - [x] HRNet
| Mask R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ✓ | ✓ | - [x] RegNet
| Cascade R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ✓ | - [x] Res2Net
| Cascade Mask R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ✓ |
| SSD | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | Supported methods:
| RetinaNet | ✓ | ✓ | ☐ | ✗ | ✓ | ✓ | ☐ | - [x] [RPN](configs/rpn)
| GHM | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Fast R-CNN](configs/fast_rcnn)
| Mask Scoring R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Faster R-CNN](configs/faster_rcnn)
| Double-Head R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Mask R-CNN](configs/mask_rcnn)
| Grid R-CNN (Plus) | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Cascade R-CNN](configs/cascade_rcnn)
| Hybrid Task Cascade| ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ✓ | - [x] [Cascade Mask R-CNN](configs/cascade_rcnn)
| Libra R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [SSD](configs/ssd)
| Guided Anchoring | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [RetinaNet](configs/retinanet)
| FCOS | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [GHM](configs/ghm)
| RepPoints | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Mask Scoring R-CNN](configs/ms_rcnn)
| Foveabox | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Double-Head R-CNN](configs/double_heads)
| FreeAnchor | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Hybrid Task Cascade](configs/htc)
| NAS-FPN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Libra R-CNN](configs/libra_rcnn)
| ATSS | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Guided Anchoring](configs/guided_anchoring)
| FSAF | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [FCOS](configs/fcos)
| PAFPN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [RepPoints](configs/reppoints)
| NAS-FCOS | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [Foveabox](configs/foveabox)
| PISA | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [FreeAnchor](configs/free_anchor)
| Dynamic R-CNN | ✓ | ✓ | ☐ | ✗ | ✓ | ☐ | ☐ | - [x] [NAS-FPN](configs/nas_fpn)
- [x] [ATSS](configs/atss)
Other features - [x] [FSAF](configs/fsaf)
- [x] [PAFPN](configs/pafpn)
- [x] [Dynamic R-CNN](configs/dynamic_rcnn)
- [x] [PointRend](configs/point_rend)
- [x] [CARAFE](configs/carafe/README.md) - [x] [CARAFE](configs/carafe/README.md)
- [x] [DCNv2](configs/dcn/README.md) - [x] [DCNv2](configs/dcn/README.md)
- [x] [Group Normalization](configs/gn/README.md) - [x] [Group Normalization](configs/gn/README.md)

@ -1,4 +1,4 @@
API Documentation API Reference
================= =================
mmdet.apis mmdet.apis

@ -135,6 +135,9 @@ Please refer to [GRoIE](https://github.com/open-mmlab/mmdetection/blob/master/co
### Dynamic R-CNN ### Dynamic R-CNN
Please refer to [Dynamic R-CNN](https://github.com/open-mmlab/mmdetection/blob/master/configs/dynamic_rcnn) for details. Please refer to [Dynamic R-CNN](https://github.com/open-mmlab/mmdetection/blob/master/configs/dynamic_rcnn) for details.
### PointRend
Please refer to [PointRend](https://github.com/open-mmlab/mmdetection/blob/master/configs/point_rend) for details.
### Other datasets ### Other datasets
We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face). We also benchmark some methods on [PASCAL VOC](https://github.com/open-mmlab/mmdetection/blob/master/configs/pascal_voc), [Cityscapes](https://github.com/open-mmlab/mmdetection/blob/master/configs/cityscapes) and [WIDER FACE](https://github.com/open-mmlab/mmdetection/blob/master/configs/wider_face).

@ -95,7 +95,7 @@ class CityscapesDataset(CocoDataset):
"""Dump the detection results to a txt file. """Dump the detection results to a txt file.
Args: Args:
results (list[list | tuple | ndarray]): Testing results of the results (list[list | tuple]): Testing results of the
dataset. dataset.
outfile_prefix (str): The filename prefix of the json files. outfile_prefix (str): The filename prefix of the json files.
If the prefix is "somepath/xxx", If the prefix is "somepath/xxx",
@ -198,14 +198,27 @@ class CityscapesDataset(CocoDataset):
classwise=False, classwise=False,
proposal_nums=(100, 300, 1000), proposal_nums=(100, 300, 1000),
iou_thrs=np.arange(0.5, 0.96, 0.05)): iou_thrs=np.arange(0.5, 0.96, 0.05)):
"""Evaluation in Cityscapes protocol. """Evaluation in Cityscapes/COCO protocol.
Args: Args:
results (list): Testing results of the dataset. results (list[list | tuple]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str]): Metrics to be evaluated. Options are
'bbox', 'segm', 'proposal', 'proposal_fast'.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
outfile_prefix (str | None): outfile_prefix (str | None): The prefix of output file. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If results are evaluated with COCO protocol, it would be the
prefix of output json file. For example, the metric is 'bbox'
and 'segm', then json files would be "a/b/prefix.bbox.json" and
"a/b/prefix.segm.json".
If results are evaluated with cityscapes protocol, it would be
the prefix of output txt/png files. The output files would be
png images under folder "a/b/prefix/xxx/" and the file name of
images would be written into a txt file
"a/b/prefix/xxx_pred.txt", where "xxx" is the video name of
cityscapes. If not specified, a temp file will be created.
Default: None.
classwise (bool): Whether to evaluate the AP for each class. classwise (bool): Whether to evaluate the AP for each class.
proposal_nums (Sequence[int]): Proposal number used for evaluating proposal_nums (Sequence[int]): Proposal number used for evaluating
recalls, such as recall@100, recall@1000. recalls, such as recall@100, recall@1000.
@ -215,7 +228,8 @@ class CityscapesDataset(CocoDataset):
also be computed. Default: 0.5. also be computed. Default: 0.5.
Returns: Returns:
dict[str: float] dict[str, float]: COCO style evaluation metric or cityscapes mAP
and AP@50.
""" """
eval_results = dict() eval_results = dict()
@ -244,6 +258,19 @@ class CityscapesDataset(CocoDataset):
return eval_results return eval_results
def _evaluate_cityscapes(self, results, txtfile_prefix, logger): def _evaluate_cityscapes(self, results, txtfile_prefix, logger):
"""Evaluation in Cityscapes protocol.
Args:
results (list): Testing results of the dataset.
txtfile_prefix (str | None): The prefix of output txt file
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str, float]: Cityscapes evaluation results, contains 'mAP'
and 'AP@50'.
"""
try: try:
import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval # noqa import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval # noqa
except ImportError: except ImportError:
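The ``outfile_prefix`` naming described in the ``evaluate`` docstring above can be illustrated with a small sketch for the COCO-protocol case. ``coco_json_paths`` is a hypothetical helper written for this note, not a function from mmdet. It only reproduces the "a/b/prefix.bbox.json" pattern quoted in the docstring.

```python
def coco_json_paths(outfile_prefix, metrics):
    """Hypothetical helper: map each metric to '<prefix>.<metric>.json'."""
    return {metric: f'{outfile_prefix}.{metric}.json' for metric in metrics}


# The prefix carries both the directory and the start of the file name.
paths = coco_json_paths('a/b/prefix', ['bbox', 'segm'])
print(paths['bbox'])
print(paths['segm'])
```

The Cityscapes-protocol case (png files under "a/b/prefix/xxx/") depends on the video names in the dataset and is not sketched here.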

@ -34,6 +34,15 @@ class CocoDataset(CustomDataset):
'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush') 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
"""Load annotation from COCO style annotation file.
Args:
ann_file (str): Path of annotation file.
Returns:
list[dict]: Annotation info from COCO api.
"""
self.coco = COCO(ann_file) self.coco = COCO(ann_file)
self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES) self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
@ -46,12 +55,30 @@ class CocoDataset(CustomDataset):
return data_infos return data_infos
def get_ann_info(self, idx): def get_ann_info(self, idx):
"""Get COCO annotation by index.
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
img_id = self.data_infos[idx]['id'] img_id = self.data_infos[idx]['id']
ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids) ann_info = self.coco.load_anns(ann_ids)
return self._parse_ann_info(self.data_infos[idx], ann_info) return self._parse_ann_info(self.data_infos[idx], ann_info)
def get_cat_ids(self, idx): def get_cat_ids(self, idx):
"""Get COCO category ids by index.
Args:
idx (int): Index of data.
Returns:
list[int]: All categories in the image of specified index.
"""
img_id = self.data_infos[idx]['id'] img_id = self.data_infos[idx]['id']
ann_ids = self.coco.get_ann_ids(img_ids=[img_id]) ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
ann_info = self.coco.load_anns(ann_ids) ann_info = self.coco.load_anns(ann_ids)
@ -153,6 +180,17 @@ class CocoDataset(CustomDataset):
return ann return ann
def xyxy2xywh(self, bbox): def xyxy2xywh(self, bbox):
"""Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
evaluation.
Args:
bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
``xyxy`` order.
Returns:
list[float]: The converted bounding boxes, in ``xywh`` order.
"""
_bbox = bbox.tolist() _bbox = bbox.tolist()
return [ return [
_bbox[0], _bbox[0],
@ -162,6 +200,7 @@ class CocoDataset(CustomDataset):
] ]
def _proposal2json(self, results): def _proposal2json(self, results):
"""Convert proposal results to COCO json style"""
json_results = [] json_results = []
for idx in range(len(self)): for idx in range(len(self)):
img_id = self.img_ids[idx] img_id = self.img_ids[idx]
@ -176,6 +215,7 @@ class CocoDataset(CustomDataset):
return json_results return json_results
def _det2json(self, results): def _det2json(self, results):
"""Convert detection results to COCO json style"""
json_results = [] json_results = []
for idx in range(len(self)): for idx in range(len(self)):
img_id = self.img_ids[idx] img_id = self.img_ids[idx]
@ -192,6 +232,7 @@ class CocoDataset(CustomDataset):
return json_results return json_results
def _segm2json(self, results): def _segm2json(self, results):
"""Convert instance segmentation results to COCO json style"""
bbox_json_results = [] bbox_json_results = []
segm_json_results = [] segm_json_results = []
for idx in range(len(self)): for idx in range(len(self)):
@ -229,7 +270,7 @@ class CocoDataset(CustomDataset):
return bbox_json_results, segm_json_results return bbox_json_results, segm_json_results
def results2json(self, results, outfile_prefix): def results2json(self, results, outfile_prefix):
"""Dump the detection results to a json file. """Dump the detection results to a COCO style json file.
There are 3 types of results: proposals, bbox predictions, mask There are 3 types of results: proposals, bbox predictions, mask
predictions, and they have different data types. This method will predictions, and they have different data types. This method will
@ -296,7 +337,8 @@ class CocoDataset(CustomDataset):
"""Format the results to json (standard format for COCO evaluation). """Format the results to json (standard format for COCO evaluation).
Args: Args:
results (list): Testing results of the dataset. results (list[tuple | numpy.ndarray]): Testing results of the
dataset.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix". the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None. If not specified, a temp file will be created. Default: None.
@ -330,8 +372,9 @@ class CocoDataset(CustomDataset):
"""Evaluation in COCO protocol. """Evaluation in COCO protocol.
Args: Args:
results (list): Testing results of the dataset. results (list[list | tuple]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str]): Metrics to be evaluated. Options are
'bbox', 'segm', 'proposal', 'proposal_fast'.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes jsonfile_prefix (str | None): The prefix of json files. It includes
@ -346,7 +389,7 @@ class CocoDataset(CustomDataset):
also be computed. Default: 0.5. also be computed. Default: 0.5.
Returns: Returns:
dict[str: float] dict[str, float]: COCO style evaluation metric.
""" """
metrics = metric if isinstance(metric, list) else [metric] metrics = metric if isinstance(metric, list) else [metric]
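As a rough illustration of the ``xyxy`` to ``xywh`` conversion documented for ``xyxy2xywh`` above, here is a standalone sketch. The hunk elides most of the method body, so the width and height arithmetic below (plain differences, with no +1 offset) is an assumption, and ``xyxy_to_xywh`` is a hypothetical name rather than the mmdet method.

```python
def xyxy_to_xywh(bbox):
    """Convert [x1, y1, x2, y2] to [x, y, w, h].

    Assumption: width and height are plain coordinate differences.
    """
    x1, y1, x2, y2 = (float(v) for v in bbox)
    return [x1, y1, x2 - x1, y2 - y1]


# Top-left corner (10, 20), bottom-right corner (50, 80).
box = xyxy_to_xywh([10, 20, 50, 80])
print(box)
```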

@ -32,6 +32,17 @@ class CustomDataset(Dataset):
}, },
... ...
] ]
Args:
ann_file (str): Annotation file path.
pipeline (list[dict]): Processing pipeline.
classes (str | Sequence[str], optional): Specify classes to load.
If None, ``cls.CLASSES`` will be used. Default: None.
data_root (str, optional): Data root for ``ann_file``,
``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
test_mode (bool, optional): If set True, annotation will not be loaded.
filter_empty_gt (bool, optional): If set True, images without bounding
boxes will be filtered out.
""" """
CLASSES = None CLASSES = None
@ -90,21 +101,43 @@ class CustomDataset(Dataset):
self.pipeline = Compose(pipeline) self.pipeline = Compose(pipeline)
def __len__(self): def __len__(self):
"""Total number of samples of data"""
return len(self.data_infos) return len(self.data_infos)
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
"""Load annotation from annotation file"""
return mmcv.load(ann_file) return mmcv.load(ann_file)
def load_proposals(self, proposal_file): def load_proposals(self, proposal_file):
"""Load proposal from proposal file"""
return mmcv.load(proposal_file) return mmcv.load(proposal_file)
def get_ann_info(self, idx): def get_ann_info(self, idx):
"""Get annotation by index
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
return self.data_infos[idx]['ann'] return self.data_infos[idx]['ann']
def get_cat_ids(self, idx): def get_cat_ids(self, idx):
"""Get category ids by index
Args:
idx (int): Index of data.
Returns:
list[int]: All categories in the image of specified index.
"""
return self.data_infos[idx]['ann']['labels'].astype(np.int).tolist() return self.data_infos[idx]['ann']['labels'].astype(np.int).tolist()
def pre_pipeline(self, results): def pre_pipeline(self, results):
"""Prepare results dict for pipeline"""
results['img_prefix'] = self.img_prefix results['img_prefix'] = self.img_prefix
results['seg_prefix'] = self.seg_prefix results['seg_prefix'] = self.seg_prefix
results['proposal_file'] = self.proposal_file results['proposal_file'] = self.proposal_file
@ -133,10 +166,21 @@ class CustomDataset(Dataset):
self.flag[i] = 1 self.flag[i] = 1
def _rand_another(self, idx): def _rand_another(self, idx):
"""Get another random index from the same group as the given index"""
pool = np.where(self.flag == self.flag[idx])[0] pool = np.where(self.flag == self.flag[idx])[0]
return np.random.choice(pool) return np.random.choice(pool)
def __getitem__(self, idx): def __getitem__(self, idx):
"""Get training/test data after pipeline
Args:
idx (int): Index of data.
Returns:
dict: Training/test data (with annotation if `test_mode` is set
False).
"""
if self.test_mode: if self.test_mode:
return self.prepare_test_img(idx) return self.prepare_test_img(idx)
while True: while True:
@ -147,6 +191,16 @@ class CustomDataset(Dataset):
return data return data
def prepare_train_img(self, idx): def prepare_train_img(self, idx):
"""Get training data and annotations after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys
introduced by pipeline.
"""
img_info = self.data_infos[idx] img_info = self.data_infos[idx]
ann_info = self.get_ann_info(idx) ann_info = self.get_ann_info(idx)
results = dict(img_info=img_info, ann_info=ann_info) results = dict(img_info=img_info, ann_info=ann_info)
@ -156,6 +210,16 @@ class CustomDataset(Dataset):
return self.pipeline(results) return self.pipeline(results)
def prepare_test_img(self, idx): def prepare_test_img(self, idx):
"""Get testing data after pipeline.
Args:
idx (int): Index of data.
Returns:
dict: Testing data after pipeline with new keys introduced by
pipeline.
"""
img_info = self.data_infos[idx] img_info = self.data_infos[idx]
results = dict(img_info=img_info) results = dict(img_info=img_info)
if self.proposals is not None: if self.proposals is not None:
@ -194,6 +258,7 @@ class CustomDataset(Dataset):
return self.data_infos return self.data_infos
def format_results(self, results, **kwargs): def format_results(self, results, **kwargs):
"""Placeholder to format result to dataset specific output"""
pass pass
def evaluate(self, def evaluate(self,
@ -219,6 +284,7 @@ class CustomDataset(Dataset):
scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP. scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
Default: None. Default: None.
""" """
if not isinstance(metric, str): if not isinstance(metric, str):
assert len(metric) == 1 assert len(metric) == 1
metric = metric[0] metric = metric[0]
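The interplay between ``__getitem__`` and ``_rand_another`` documented above can be sketched with a toy class. The body of the ``while True`` loop is not visible in this diff, so the retry-on-``None`` behaviour below is an assumption. ``ToyDataset`` is a hypothetical stand-in that uses the standard ``random`` module instead of NumPy.

```python
import random


class ToyDataset:
    """Toy stand-in for a dataset; not the mmdet implementation."""

    def __init__(self, samples, flags):
        self.samples = samples  # None marks a sample the pipeline rejects
        self.flags = flags      # one group flag per sample

    def _rand_another(self, idx):
        # Pick a random index from the same group as ``idx``.
        pool = [i for i, f in enumerate(self.flags) if f == self.flags[idx]]
        return random.choice(pool)

    def __getitem__(self, idx):
        while True:
            data = self.samples[idx]
            if data is None:
                # Assumed behaviour: retry with another sample of the group.
                idx = self._rand_another(idx)
                continue
            return data


dataset = ToyDataset(samples=[None, 'a', 'b', 'c'], flags=[0, 0, 1, 1])
item = dataset[0]  # index 0 is rejected, so the group-0 fallback is index 1
print(item)
```

Restricting the fallback to the same group keeps the retried sample in the same flag group as the one originally requested.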

@ -29,6 +29,15 @@ class ConcatDataset(_ConcatDataset):
self.flag = np.concatenate(flags) self.flag = np.concatenate(flags)
def get_cat_ids(self, idx): def get_cat_ids(self, idx):
"""Get category ids of concatenated dataset by index
Args:
idx (int): Index of data.
Returns:
list[int]: All categories in the image of specified index.
"""
if idx < 0: if idx < 0:
if -idx > len(self): if -idx > len(self):
raise ValueError( raise ValueError(
@ -69,9 +78,19 @@ class RepeatDataset(object):
return self.dataset[idx % self._ori_len] return self.dataset[idx % self._ori_len]
def get_cat_ids(self, idx): def get_cat_ids(self, idx):
"""Get category ids of repeat dataset by index
Args:
idx (int): Index of data.
Returns:
list[int]: All categories in the image of specified index.
"""
return self.dataset.get_cat_ids(idx % self._ori_len) return self.dataset.get_cat_ids(idx % self._ori_len)
def __len__(self): def __len__(self):
"""Length after repetition"""
return self.times * self._ori_len return self.times * self._ori_len
@ -128,6 +147,18 @@ class ClassBalancedDataset(object):
self.flag = np.asarray(flags, dtype=np.uint8) self.flag = np.asarray(flags, dtype=np.uint8)
def _get_repeat_factors(self, dataset, repeat_thr): def _get_repeat_factors(self, dataset, repeat_thr):
"""Get repeat factor for each image in the dataset.
Args:
dataset (:obj:`CustomDataset`): The dataset
repeat_thr (float): The threshold of frequency. If an image
contains the categories whose frequency is below the threshold,
it would be repeated.
Returns:
list[float]: The repeat factors for each image in the dataset.
"""
# 1. For each category c, compute the fraction # of images # 1. For each category c, compute the fraction # of images
# that contain it: f(c) # that contain it: f(c)
category_freq = defaultdict(int) category_freq = defaultdict(int)
@ -163,4 +194,5 @@ class ClassBalancedDataset(object):
return self.dataset[ori_index] return self.dataset[ori_index]
def __len__(self): def __len__(self):
"""Length after repetition"""
return len(self.repeat_indices) return len(self.repeat_indices)
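The index wrapping used by ``RepeatDataset`` above (``idx % self._ori_len`` and a length of ``times * _ori_len``) can be tried out with a minimal sketch. ``ToyRepeat`` is a hypothetical class that mirrors only those two lines. It omits ``get_cat_ids``, ``CLASSES`` and ``flag`` handling.

```python
class ToyRepeat:
    """Toy wrapper that repeats a sequence ``times`` times."""

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self._ori_len = len(dataset)

    def __getitem__(self, idx):
        # Wrap the index back into the range of the original dataset.
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """Length after repetition"""
        return self.times * self._ori_len


repeated = ToyRepeat(['a', 'b', 'c'], times=2)
print(len(repeated), repeated[4])
```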

@ -265,6 +265,15 @@ class LVISDataset(CocoDataset):
'yoke_(animal_equipment)', 'zebra', 'zucchini') 'yoke_(animal_equipment)', 'zebra', 'zucchini')
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
"""Load annotation from lvis style annotation file
Args:
ann_file (str): Path of annotation file.
Returns:
list[dict]: Annotation info from LVIS api.
"""
try: try:
from lvis import LVIS from lvis import LVIS
except ImportError: except ImportError:
@ -298,9 +307,11 @@ class LVISDataset(CocoDataset):
proposal_nums=(100, 300, 1000), proposal_nums=(100, 300, 1000),
iou_thrs=np.arange(0.5, 0.96, 0.05)): iou_thrs=np.arange(0.5, 0.96, 0.05)):
"""Evaluation in LVIS protocol. """Evaluation in LVIS protocol.
Args: Args:
results (list): Testing results of the dataset. results (list[list | tuple]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. metric (str | list[str]): Metrics to be evaluated. Options are
'bbox', 'segm', 'proposal', 'proposal_fast'.
logger (logging.Logger | str | None): Logger used for printing logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None. related information during evaluation. Default: None.
jsonfile_prefix (str | None): jsonfile_prefix (str | None):
@ -311,9 +322,11 @@ class LVISDataset(CocoDataset):
iou_thrs (Sequence[float]): IoU threshold used for evaluating iou_thrs (Sequence[float]): IoU threshold used for evaluating
recalls. If set to a list, the average recall of all IoUs will recalls. If set to a list, the average recall of all IoUs will
also be computed. Default: 0.5. also be computed. Default: 0.5.
Returns: Returns:
dict[str: float] dict[str, float]: LVIS style metrics.
""" """
try: try:
from lvis import LVISResults, LVISEval from lvis import LVISResults, LVISEval
except ImportError: except ImportError:

@ -7,6 +7,12 @@ from ..builder import PIPELINES
@PIPELINES.register_module() @PIPELINES.register_module()
class Compose(object): class Compose(object):
"""Compose multiple transforms sequentially.
Args:
transforms (Sequence[dict | callable]): Sequence of transform objects or
config dicts to be composed.
"""
def __init__(self, transforms): def __init__(self, transforms):
assert isinstance(transforms, collections.abc.Sequence) assert isinstance(transforms, collections.abc.Sequence)
@ -21,6 +27,15 @@ class Compose(object):
raise TypeError('transform must be callable or a dict') raise TypeError('transform must be callable or a dict')
def __call__(self, data): def __call__(self, data):
"""Call function to apply transforms sequentially.
Args:
data (dict): A result dict contains the data to transform.
Returns:
dict: Transformed data.
"""
for t in self.transforms: for t in self.transforms:
data = t(data) data = t(data)
if data is None: if data is None:
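The sequential application described in the ``Compose`` docstrings above can be sketched without the registry. ``ToyCompose`` is a hypothetical class that accepts only callables (the real class also builds transforms from config dicts). Returning ``None`` as soon as a transform yields ``None`` is an assumption, since the hunk ends at the ``if data is None:`` check.

```python
class ToyCompose:
    """Toy version of a transform pipeline; callables only."""

    def __init__(self, transforms):
        for t in transforms:
            if not callable(t):
                raise TypeError('transform must be callable')
        self.transforms = list(transforms)

    def __call__(self, data):
        for t in self.transforms:
            data = t(data)
            if data is None:
                # Assumption: a rejected sample stops the pipeline early.
                return None
        return data


def add_one(d):
    return {**d, 'x': d['x'] + 1}


def double(d):
    return {**d, 'x': d['x'] * 2}


def drop(d):
    return None


result = ToyCompose([add_one, double])({'x': 1})
dropped = ToyCompose([add_one, drop, double])({'x': 1})
print(result, dropped)
```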

@ -13,7 +13,12 @@ def to_tensor(data):
Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
:class:`Sequence`, :class:`int` and :class:`float`. :class:`Sequence`, :class:`int` and :class:`float`.
Args:
data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
be converted.
""" """
if isinstance(data, torch.Tensor): if isinstance(data, torch.Tensor):
return data return data
elif isinstance(data, np.ndarray): elif isinstance(data, np.ndarray):
@ -30,11 +35,25 @@ def to_tensor(data):
@PIPELINES.register_module() @PIPELINES.register_module()
class ToTensor(object): class ToTensor(object):
"""Convert some results to :obj:`torch.Tensor` by given keys.
Args:
keys (Sequence[str]): Keys that need to be converted to Tensor.
"""
def __init__(self, keys): def __init__(self, keys):
self.keys = keys self.keys = keys
def __call__(self, results): def __call__(self, results):
"""Call function to convert data in results to :obj:`torch.Tensor`.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data converted
to :obj:`torch.Tensor`.
"""
for key in self.keys: for key in self.keys:
results[key] = to_tensor(results[key]) results[key] = to_tensor(results[key])
return results return results
@ -45,11 +64,30 @@ class ToTensor(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class ImageToTensor(object): class ImageToTensor(object):
"""Convert image to :obj:`torch.Tensor` by given keys.
The dimension order of input image is (H, W, C). The pipeline will convert
it to (C, H, W). If only 2 dimensions (H, W) are given, the output would be
(1, H, W).
Args:
keys (Sequence[str]): Key of images to be converted to Tensor.
"""
def __init__(self, keys): def __init__(self, keys):
self.keys = keys self.keys = keys
def __call__(self, results): def __call__(self, results):
"""Call function to convert image in results to :obj:`torch.Tensor`
and transpose the channel order.
Args:
results (dict): Result dict contains the image data to convert.
Returns:
dict: The result dict contains the image converted
to :obj:`torch.Tensor` and transposed to (C, H, W) order.
"""
for key in self.keys: for key in self.keys:
img = results[key] img = results[key]
if len(img.shape) < 3: if len(img.shape) < 3:
@ -63,12 +101,27 @@ class ImageToTensor(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class Transpose(object): class Transpose(object):
"""Transpose some results by given keys.
Args:
keys (Sequence[str]): Keys of results to be transposed.
order (Sequence[int]): Order of transpose.
"""
def __init__(self, keys, order): def __init__(self, keys, order):
self.keys = keys self.keys = keys
self.order = order self.order = order
def __call__(self, results): def __call__(self, results):
"""Call function to transpose the channel order of data in results.
Args:
results (dict): Result dict contains the data to transpose.
Returns:
dict: The result dict contains the data transposed to
``self.order``.
"""
for key in self.keys: for key in self.keys:
results[key] = results[key].transpose(self.order) results[key] = results[key].transpose(self.order)
return results return results
@ -80,6 +133,15 @@ class Transpose(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class ToDataContainer(object): class ToDataContainer(object):
"""Convert results to :obj:`mmcv.DataContainer` by given fields.
Args:
fields (Sequence[dict]): Each field is a dict like
``dict(key='xxx', **kwargs)``. The ``key`` in result will
be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
Default: ``(dict(key='img', stack=True), dict(key='gt_bboxes'),
dict(key='gt_labels'))``.
"""
def __init__(self, def __init__(self,
fields=(dict(key='img', stack=True), dict(key='gt_bboxes'), fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
@ -87,6 +149,17 @@ class ToDataContainer(object):
self.fields = fields self.fields = fields
def __call__(self, results): def __call__(self, results):
"""Call function to convert data in results to
:obj:`mmcv.DataContainer`.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data converted to
:obj:`mmcv.DataContainer`.
"""
for field in self.fields: for field in self.fields:
field = field.copy() field = field.copy()
key = field.pop('key') key = field.pop('key')
@ -116,6 +189,16 @@ class DefaultFormatBundle(object):
""" """
def __call__(self, results): def __call__(self, results):
"""Call function to transform and format common fields in results.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data that is formatted with
default bundle.
"""
if 'img' in results: if 'img' in results:
img = results['img'] img = results['img']
if len(img.shape) < 3: if len(img.shape) < 3:
@ -167,6 +250,14 @@ class Collect(object):
- mean - per channel mean subtraction - mean - per channel mean subtraction
- std - per channel std divisor - std - per channel std divisor
- to_rgb - bool indicating if bgr was converted to rgb - to_rgb - bool indicating if bgr was converted to rgb
Args:
keys (Sequence[str]): Keys of results to be collected in ``data``.
meta_keys (Sequence[str], optional): Meta keys to be converted to
``mmcv.DataContainer`` and collected in ``data[img_metas]``.
Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
'pad_shape', 'scale_factor', 'flip', 'flip_direction',
'img_norm_cfg')``
""" """
def __init__(self, def __init__(self,
@ -178,6 +269,18 @@ class Collect(object):
self.meta_keys = meta_keys self.meta_keys = meta_keys
def __call__(self, results): def __call__(self, results):
"""Call function to collect keys in results. The keys in ``meta_keys``
will be converted to :obj:`mmcv.DataContainer`.
Args:
results (dict): Result dict contains the data to collect.
Returns:
dict: The result dict contains the following keys
- keys in ``self.keys``
- ``img_metas``
"""
data = {} data = {}
img_meta = {} img_meta = {}
for key in self.meta_keys: for key in self.meta_keys:
@ -215,6 +318,16 @@ class WrapFieldsToLists(object):
""" """
def __call__(self, results): def __call__(self, results):
"""Call function to wrap fields into lists.
Args:
results (dict): Result dict contains the data to wrap.
Returns:
dict: The result dict where value of ``self.keys`` are wrapped
into list.
"""
# Wrap dict fields into lists # Wrap dict fields into lists
for key, val in results.items(): for key, val in results.items():
results[key] = [val] results[key] = [val]
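The behaviour documented for ``Collect`` and ``WrapFieldsToLists`` above can be approximated with plain dicts. Both functions below are hypothetical sketches: ``collect`` stores the meta information as an ordinary dict rather than an ``mmcv.DataContainer``, and ``wrap_fields_to_lists`` returns a new dict instead of modifying ``results`` in place as the diff shows.

```python
def collect(results, keys, meta_keys):
    """Keep ``keys`` and gather ``meta_keys`` under 'img_metas'."""
    img_meta = {key: results[key] for key in meta_keys}
    # The real transform wraps ``img_meta`` in a DataContainer.
    data = {'img_metas': img_meta}
    for key in keys:
        data[key] = results[key]
    return data


def wrap_fields_to_lists(results):
    """Wrap every value of ``results`` into a one-element list."""
    return {key: [val] for key, val in results.items()}


results = {'img': 'pixels', 'gt_labels': [1, 2], 'filename': 'a.jpg',
           'flip': False, 'scratch': 'dropped'}
data = collect(results, keys=['img', 'gt_labels'],
               meta_keys=['filename', 'flip'])
wrapped = wrap_fields_to_lists(data)
print(sorted(data))
```

Keys that are listed in neither ``keys`` nor ``meta_keys`` (here ``scratch``) do not appear in the collected output.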

@ -38,6 +38,15 @@ class LoadImageFromFile(object):
self.file_client = None self.file_client = None
def __call__(self, results): def __call__(self, results):
"""Call functions to load image and get image meta information.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded image and meta information.
"""
if self.file_client is None: if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args) self.file_client = mmcv.FileClient(**self.file_client_args)
@ -107,6 +116,16 @@ class LoadMultiChannelImageFromFiles(object):
self.file_client = None self.file_client = None
def __call__(self, results): def __call__(self, results):
"""Call functions to load multiple images and get images meta
information.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded images and meta information.
"""
if self.file_client is None: if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args) self.file_client = mmcv.FileClient(**self.file_client_args)
@ -151,7 +170,7 @@ class LoadMultiChannelImageFromFiles(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class LoadAnnotations(object): class LoadAnnotations(object):
"""Load annotations. """Load multiple types of annotations.
Args: Args:
with_bbox (bool): Whether to parse and load the bbox annotation. with_bbox (bool): Whether to parse and load the bbox annotation.
@ -185,6 +204,15 @@ class LoadAnnotations(object):
self.file_client = None self.file_client = None
def _load_bboxes(self, results): def _load_bboxes(self, results):
"""Private function to load bounding box annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded bounding box annotations.
"""
ann_info = results['ann_info'] ann_info = results['ann_info']
results['gt_bboxes'] = ann_info['bboxes'].copy() results['gt_bboxes'] = ann_info['bboxes'].copy()
@ -196,10 +224,31 @@ class LoadAnnotations(object):
return results return results
def _load_labels(self, results): def _load_labels(self, results):
"""Private function to load label annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded label annotations.
"""
results['gt_labels'] = results['ann_info']['labels'].copy() results['gt_labels'] = results['ann_info']['labels'].copy()
return results return results
def _poly2mask(self, mask_ann, img_h, img_w): def _poly2mask(self, mask_ann, img_h, img_w):
"""Private function to convert masks represented with polygon to
bitmaps.
Args:
mask_ann (list | dict): Polygon mask annotation input.
img_h (int): The height of output mask.
img_w (int): The width of output mask.
Returns:
numpy.ndarray: The decoded bitmap mask of shape (img_h, img_w).
"""
if isinstance(mask_ann, list): if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts # polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code # we merge all parts into one mask rle code
@ -218,11 +267,12 @@ class LoadAnnotations(object):
"""Convert polygons to list of ndarray and filter invalid polygons. """Convert polygons to list of ndarray and filter invalid polygons.
Args: Args:
polygons (list[list]): polygons of one instance. polygons (list[list]): Polygons of one instance.
Returns: Returns:
list[ndarray]: processed polygons. list[numpy.ndarray]: Processed polygons.
""" """
polygons = [np.array(p) for p in polygons] polygons = [np.array(p) for p in polygons]
valid_polygons = [] valid_polygons = []
for polygon in polygons: for polygon in polygons:
@ -231,6 +281,17 @@ class LoadAnnotations(object):
return valid_polygons return valid_polygons
def _load_masks(self, results): def _load_masks(self, results):
"""Private function to load mask annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded mask annotations.
If ``self.poly2mask`` is set ``True``, ``gt_masks`` will contain
:obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks` is used.
"""
h, w = results['img_info']['height'], results['img_info']['width'] h, w = results['img_info']['height'], results['img_info']['width']
gt_masks = results['ann_info']['masks'] gt_masks = results['ann_info']['masks']
if self.poly2mask: if self.poly2mask:
@ -245,6 +306,15 @@ class LoadAnnotations(object):
return results return results
def _load_semantic_seg(self, results): def _load_semantic_seg(self, results):
"""Private function to load semantic segmentation annotations.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: The dict contains loaded semantic segmentation annotations.
"""
if self.file_client is None: if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args) self.file_client = mmcv.FileClient(**self.file_client_args)
@ -257,6 +327,16 @@ class LoadAnnotations(object):
return results return results
def __call__(self, results): def __call__(self, results):
"""Call function to load multiple types of annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded bounding box, label, mask and
semantic segmentation annotations.
"""
if self.with_bbox: if self.with_bbox:
results = self._load_bboxes(results) results = self._load_bboxes(results)
if results is None: if results is None:
@ -282,11 +362,28 @@ class LoadAnnotations(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class LoadProposals(object): class LoadProposals(object):
"""Load proposal pipeline.
Required key is "proposals". Updated keys are "proposals", "bbox_fields".
Args:
num_max_proposals (int, optional): Maximum number of proposals to load.
If not specified, all proposals will be loaded.
"""
def __init__(self, num_max_proposals=None): def __init__(self, num_max_proposals=None):
self.num_max_proposals = num_max_proposals self.num_max_proposals = num_max_proposals
def __call__(self, results): def __call__(self, results):
"""Call function to load proposals from file.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded proposal annotations.
"""
proposals = results['proposals'] proposals = results['proposals']
if proposals.shape[1] not in (4, 5): if proposals.shape[1] not in (4, 5):
raise AssertionError( raise AssertionError(

@ -10,6 +10,34 @@ from .compose import Compose
class MultiScaleFlipAug(object): class MultiScaleFlipAug(object):
"""Test-time augmentation with multiple scales and flipping """Test-time augmentation with multiple scales and flipping
An example configuration is as follows:
.. code-block::
img_scale=[(1333, 400), (1333, 800)],
flip=True,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
]
After MultiScaleFlipAug with the above configuration, the results are wrapped
into lists of the same length as follows:
.. code-block::
dict(
img=[...],
img_shape=[...],
scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)],
flip=[False, True, False, True],
...
)
Args: Args:
transforms (list[dict]): Transforms to apply in each augmentation. transforms (list[dict]): Transforms to apply in each augmentation.
img_scale (tuple | list[tuple]): Images scales for resizing. img_scale (tuple | list[tuple]): Images scales for resizing.
@ -42,6 +70,16 @@ class MultiScaleFlipAug(object):
'flip has no effect when RandomFlip is not in transforms') 'flip has no effect when RandomFlip is not in transforms')
def __call__(self, results): def __call__(self, results):
"""Call function to apply test time augment transforms on results.
Args:
results (dict): Result dict contains the data to transform.
Returns:
dict[str, list]: The augmented data, where each value is wrapped
into a list.
"""
aug_data = [] aug_data = []
flip_aug = [False, True] if self.flip else [False] flip_aug = [False, True] if self.flip else [False]
for scale in self.img_scale: for scale in self.img_scale:
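To make the list layout described in the `MultiScaleFlipAug` docstring concrete, here is a minimal sketch of how the scale and flip combinations could be enumerated. The function name `enumerate_tta_settings` is hypothetical, and the inner loop over `flip_aug` is an assumption inferred from the documented output; the real class also runs the transforms, which is omitted here.

```python
def enumerate_tta_settings(img_scales, flip):
    """List the (scale, flip) pairs for test-time augmentation."""
    flip_aug = [False, True] if flip else [False]
    settings = {'scale': [], 'flip': []}
    for scale in img_scales:
        # Assumed nesting: every scale is paired with every flip option.
        for do_flip in flip_aug:
            settings['scale'].append(scale)
            settings['flip'].append(do_flip)
    return settings


settings = enumerate_tta_settings([(1333, 400), (1333, 800)], flip=True)
```

With two scales and flipping enabled this yields four entries per key, in the same order as the example dict in the docstring.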

@ -75,6 +75,17 @@ class Resize(object):
@staticmethod @staticmethod
def random_select(img_scales): def random_select(img_scales):
"""Randomly select an img_scale from given candidates.
Args:
img_scales (list[tuple]): Image scales for selection.
Returns:
(tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
where ``img_scale`` is the selected image scale and
``scale_idx`` is the selected index in the given candidates.
"""
assert mmcv.is_list_of(img_scales, tuple) assert mmcv.is_list_of(img_scales, tuple)
scale_idx = np.random.randint(len(img_scales)) scale_idx = np.random.randint(len(img_scales))
img_scale = img_scales[scale_idx] img_scale = img_scales[scale_idx]
@ -82,6 +93,19 @@ class Resize(object):
@staticmethod @staticmethod
def random_sample(img_scales): def random_sample(img_scales):
"""Randomly sample an img_scale when ``multiscale_mode=='range'``.
Args:
img_scales (list[tuple]): Image scale range for sampling.
There must be two tuples in img_scales, which specify the lower
and upper bound of image scales.
Returns:
(tuple, None): Returns a tuple ``(img_scale, None)``, where
``img_scale`` is sampled scale and None is just a placeholder
to be consistent with :func:`random_select`.
"""
assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
img_scale_long = [max(s) for s in img_scales] img_scale_long = [max(s) for s in img_scales]
img_scale_short = [min(s) for s in img_scales] img_scale_short = [min(s) for s in img_scales]
@ -96,6 +120,24 @@ class Resize(object):
@staticmethod @staticmethod
def random_sample_ratio(img_scale, ratio_range): def random_sample_ratio(img_scale, ratio_range):
"""Randomly sample an img_scale when ``ratio_range`` is specified.
A ratio will be randomly sampled from the range specified by
``ratio_range``. Then it would be multiplied with ``img_scale`` to
generate sampled scale.
Args:
img_scale (tuple): Image scale base to multiply with ratio.
ratio_range (tuple[float]): The minimum and maximum ratio to scale
the ``img_scale``.
Returns:
(tuple, None): Returns a tuple ``(scale, None)``, where
``scale`` is sampled ratio multiplied with ``img_scale`` and
None is just a placeholder to be consistent with
:func:`random_select`.
"""
assert isinstance(img_scale, tuple) and len(img_scale) == 2 assert isinstance(img_scale, tuple) and len(img_scale) == 2
min_ratio, max_ratio = ratio_range min_ratio, max_ratio = ratio_range
assert min_ratio <= max_ratio assert min_ratio <= max_ratio
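The ratio-based sampling described in the `random_sample_ratio` docstring can be sketched in plain Python as below. This is only an illustration under assumptions: the function name `sample_scale_by_ratio` is made up, and the standard-library `random.uniform` stands in for whatever random source the actual implementation uses.

```python
import random


def sample_scale_by_ratio(img_scale, ratio_range):
    """Scale ``img_scale`` by a ratio drawn from ``ratio_range``."""
    assert isinstance(img_scale, tuple) and len(img_scale) == 2
    min_ratio, max_ratio = ratio_range
    assert min_ratio <= max_ratio
    # Assumption: a uniform draw between the two bounds.
    ratio = random.uniform(min_ratio, max_ratio)
    scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
    # ``None`` is the placeholder for ``scale_idx``, as in the docstring.
    return scale, None


scale, scale_idx = sample_scale_by_ratio((1333, 800), (0.5, 1.5))
```

The second return value stays `None` so the result has the same shape as the one documented for `random_select`.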
@ -104,6 +146,23 @@ class Resize(object):
return scale, None return scale, None
def _random_scale(self, results): def _random_scale(self, results):
"""Randomly sample an img_scale according to ``ratio_range`` and
``multiscale_mode``.
If ``ratio_range`` is specified, a ratio will be sampled and be
multiplied with ``img_scale``.
If multiple scales are specified by ``img_scale``, a scale will be
sampled according to ``multiscale_mode``.
Otherwise, single scale will be used.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: Two new keys 'scale' and 'scale_idx' are added into
``results``, which would be used by subsequent pipelines.
"""
if self.ratio_range is not None: if self.ratio_range is not None:
scale, scale_idx = self.random_sample_ratio( scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range) self.img_scale[0], self.ratio_range)
@ -120,6 +179,7 @@ class Resize(object):
results['scale_idx'] = scale_idx results['scale_idx'] = scale_idx
def _resize_img(self, results): def _resize_img(self, results):
"""Resize images with ``results['scale']``."""
for key in results.get('img_fields', ['img']): for key in results.get('img_fields', ['img']):
if self.keep_ratio: if self.keep_ratio:
img, scale_factor = mmcv.imrescale( img, scale_factor = mmcv.imrescale(
@ -144,6 +204,7 @@ class Resize(object):
results['keep_ratio'] = self.keep_ratio results['keep_ratio'] = self.keep_ratio
def _resize_bboxes(self, results): def _resize_bboxes(self, results):
"""Resize bounding boxes with ``results['scale_factor']``."""
img_shape = results['img_shape'] img_shape = results['img_shape']
for key in results.get('bbox_fields', []): for key in results.get('bbox_fields', []):
bboxes = results[key] * results['scale_factor'] bboxes = results[key] * results['scale_factor']
@ -152,6 +213,7 @@ class Resize(object):
results[key] = bboxes results[key] = bboxes
def _resize_masks(self, results): def _resize_masks(self, results):
"""Resize masks with ``results['scale']``."""
for key in results.get('mask_fields', []): for key in results.get('mask_fields', []):
if results[key] is None: if results[key] is None:
continue continue
@ -161,6 +223,7 @@ class Resize(object):
results[key] = results[key].resize(results['img_shape'][:2]) results[key] = results[key].resize(results['img_shape'][:2])
def _resize_seg(self, results): def _resize_seg(self, results):
"""Resize semantic segmentation map with ``results['scale']``."""
for key in results.get('seg_fields', []): for key in results.get('seg_fields', []):
if self.keep_ratio: if self.keep_ratio:
gt_seg = mmcv.imrescale( gt_seg = mmcv.imrescale(
@ -171,6 +234,17 @@ class Resize(object):
results['gt_semantic_seg'] = gt_seg results['gt_semantic_seg'] = gt_seg
def __call__(self, results): def __call__(self, results):
"""Call function to resize images, bounding boxes, masks, semantic
segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
'keep_ratio' keys are added into result dict.
"""
if 'scale' not in results: if 'scale' not in results:
self._random_scale(results) self._random_scale(results)
self._resize_img(results) self._resize_img(results)
@ -197,22 +271,31 @@ class RandomFlip(object):
method. method.
Args: Args:
flip_ratio (float, optional): The flipping probability. flip_ratio (float, optional): The flipping probability. Default: None.
direction (str, optional): The flipping direction. Options are
'horizontal' and 'vertical'. Default: 'horizontal'.
""" """
def __init__(self, flip_ratio=0.5, direction='horizontal'): def __init__(self, flip_ratio=None, direction='horizontal'):
assert flip_ratio >= 0 and flip_ratio <= 1
assert direction in ['horizontal', 'vertical']
self.flip_ratio = flip_ratio self.flip_ratio = flip_ratio
self.direction = direction self.direction = direction
if flip_ratio is not None:
assert flip_ratio >= 0 and flip_ratio <= 1
assert direction in ['horizontal', 'vertical']
def bbox_flip(self, bboxes, img_shape, direction): def bbox_flip(self, bboxes, img_shape, direction):
"""Flip bboxes horizontally. """Flip bboxes horizontally.
Args: Args:
bboxes(ndarray): shape (..., 4*k) bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
img_shape(tuple): (height, width) img_shape (tuple[int]): Image shape (height, width)
direction (str): Flip direction. Options are 'horizontal',
'vertical'.
Returns:
numpy.ndarray: Flipped bounding boxes.
""" """
assert bboxes.shape[-1] % 4 == 0 assert bboxes.shape[-1] % 4 == 0
flipped = bboxes.copy() flipped = bboxes.copy()
if direction == 'horizontal': if direction == 'horizontal':
@ -228,6 +311,17 @@ class RandomFlip(object):
return flipped return flipped
def __call__(self, results): def __call__(self, results):
"""Call function to flip bounding boxes, masks, semantic segmentation
maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Flipped results, 'flip', 'flip_direction' keys are added into
result dict.
"""
if 'flip' not in results: if 'flip' not in results:
flip = True if np.random.rand() < self.flip_ratio else False flip = True if np.random.rand() < self.flip_ratio else False
results['flip'] = flip results['flip'] = flip
@ -263,6 +357,7 @@ class Pad(object):
There are two padding modes: (1) pad to a fixed size and (2) pad to the There are two padding modes: (1) pad to a fixed size and (2) pad to the
minimum size that is divisible by some number. minimum size that is divisible by some number.
Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor".
Args: Args:
size (tuple, optional): Fixed padding size. size (tuple, optional): Fixed padding size.
@ -279,6 +374,7 @@ class Pad(object):
assert size is None or size_divisor is None assert size is None or size_divisor is None
def _pad_img(self, results): def _pad_img(self, results):
"""Pad images according to ``self.size``."""
for key in results.get('img_fields', ['img']): for key in results.get('img_fields', ['img']):
if self.size is not None: if self.size is not None:
padded_img = mmcv.impad(results[key], self.size, self.pad_val) padded_img = mmcv.impad(results[key], self.size, self.pad_val)
@ -291,15 +387,27 @@ class Pad(object):
results['pad_size_divisor'] = self.size_divisor results['pad_size_divisor'] = self.size_divisor
def _pad_masks(self, results): def _pad_masks(self, results):
"""Pad masks according to ``results['pad_shape']``."""
pad_shape = results['pad_shape'][:2] pad_shape = results['pad_shape'][:2]
for key in results.get('mask_fields', []): for key in results.get('mask_fields', []):
results[key] = results[key].pad(pad_shape, pad_val=self.pad_val) results[key] = results[key].pad(pad_shape, pad_val=self.pad_val)
def _pad_seg(self, results): def _pad_seg(self, results):
"""Pad semantic segmentation map according to
``results['pad_shape']``."""
for key in results.get('seg_fields', []): for key in results.get('seg_fields', []):
results[key] = mmcv.impad(results[key], results['pad_shape'][:2]) results[key] = mmcv.impad(results[key], results['pad_shape'][:2])
def __call__(self, results): def __call__(self, results):
"""Call function to pad images, masks, semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Updated result dict.
"""
self._pad_img(results) self._pad_img(results)
self._pad_masks(results) self._pad_masks(results)
self._pad_seg(results) self._pad_seg(results)
@ -317,6 +425,8 @@ class Pad(object):
class Normalize(object): class Normalize(object):
"""Normalize the image. """Normalize the image.
Added key is "img_norm_cfg".
Args: Args:
mean (sequence): Mean values of 3 channels. mean (sequence): Mean values of 3 channels.
std (sequence): Std values of 3 channels. std (sequence): Std values of 3 channels.
@ -330,6 +440,16 @@ class Normalize(object):
self.to_rgb = to_rgb self.to_rgb = to_rgb
def __call__(self, results): def __call__(self, results):
"""Call function to normalize images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Normalized results, 'img_norm_cfg' key is added into
result dict.
"""
for key in results.get('img_fields', ['img']): for key in results.get('img_fields', ['img']):
results[key] = mmcv.imnormalize(results[key], self.mean, self.std, results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
self.to_rgb) self.to_rgb)
@ -374,6 +494,17 @@ class RandomCrop(object):
} }
def __call__(self, results): def __call__(self, results):
"""Call function to randomly crop images, bounding boxes, masks,
semantic segmentation maps.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Randomly cropped results, 'img_shape' key in result dict is
updated according to crop size.
"""
for key in results.get('img_fields', ['img']): for key in results.get('img_fields', ['img']):
img = results[key] img = results[key]
margin_h = max(img.shape[0] - self.crop_size[0], 0) margin_h = max(img.shape[0] - self.crop_size[0], 0)
@ -444,6 +575,15 @@ class SegRescale(object):
self.scale_factor = scale_factor self.scale_factor = scale_factor
def __call__(self, results): def __call__(self, results):
"""Call function to scale the semantic segmentation map.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with semantic segmentation map scaled.
"""
for key in results.get('seg_fields', []): for key in results.get('seg_fields', []):
if self.scale_factor != 1: if self.scale_factor != 1:
results[key] = mmcv.imrescale( results[key] = mmcv.imrescale(
@ -487,6 +627,15 @@ class PhotoMetricDistortion(object):
self.hue_delta = hue_delta self.hue_delta = hue_delta
def __call__(self, results): def __call__(self, results):
"""Call function to perform photometric distortion on images.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images distorted.
"""
if 'img_fields' in results: if 'img_fields' in results:
assert results['img_fields'] == ['img'], \ assert results['img_fields'] == ['img'], \
'Only single img_fields is allowed' 'Only single img_fields is allowed'
@ -582,6 +731,15 @@ class Expand(object):
self.prob = prob self.prob = prob
def __call__(self, results): def __call__(self, results):
"""Call function to expand images, bounding boxes.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images and bounding boxes expanded.
"""
if random.uniform(0, 1) > self.prob: if random.uniform(0, 1) > self.prob:
return results return results
@ -661,6 +819,17 @@ class MinIoURandomCrop(object):
} }
def __call__(self, results): def __call__(self, results):
"""Call function to crop images and bounding boxes with minimum IoU
constraint.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images and bounding boxes cropped,
'img_shape' key is updated.
"""
if 'img_fields' in results: if 'img_fields' in results:
assert results['img_fields'] == ['img'], \ assert results['img_fields'] == ['img'], \
'Only single img_fields is allowed' 'Only single img_fields is allowed'
@ -751,12 +920,30 @@ class MinIoURandomCrop(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class Corrupt(object): class Corrupt(object):
"""Corruption augmentation.
Corruption transforms implemented based on
`imagecorruptions <https://github.com/bethgelab/imagecorruptions>`_.
Args:
corruption (str): Corruption name.
severity (int, optional): The severity of corruption. Default: 1.
"""
def __init__(self, corruption, severity=1): def __init__(self, corruption, severity=1):
self.corruption = corruption self.corruption = corruption
self.severity = severity self.severity = severity
def __call__(self, results): def __call__(self, results):
"""Call function to corrupt image.
Args:
results (dict): Result dict from loading pipeline.
Returns:
dict: Result dict with images corrupted.
"""
if corrupt is None: if corrupt is None:
raise RuntimeError('imagecorruptions is not installed') raise RuntimeError('imagecorruptions is not installed')
if 'img_fields' in results: if 'img_fields' in results:
@ -777,6 +964,46 @@ class Corrupt(object):
@PIPELINES.register_module() @PIPELINES.register_module()
class Albu(object): class Albu(object):
"""Albumentation augmentation.
Adds custom transformations from the Albumentations library.
Please visit `https://albumentations.readthedocs.io`
to get more information.
An example of ``transforms`` is as follows:
.. code-block::
[
dict(
type='ShiftScaleRotate',
shift_limit=0.0625,
scale_limit=0.0,
rotate_limit=0,
interpolation=1,
p=0.5),
dict(
type='RandomBrightnessContrast',
brightness_limit=[0.1, 0.3],
contrast_limit=[0.1, 0.3],
p=0.2),
dict(type='ChannelShuffle', p=0.1),
dict(
type='OneOf',
transforms=[
dict(type='Blur', blur_limit=3, p=1.0),
dict(type='MedianBlur', blur_limit=3, p=1.0)
],
p=0.1),
]
Args:
transforms (list[dict]): A list of albu transformations
bbox_params (dict): Bbox_params for albumentation `Compose`
keymap (dict): Contains {'input key':'albumentation-style key'}
skip_img_without_anno (bool): Whether to skip the image if no ann left
after aug
"""
def __init__(self, def __init__(self,
transforms, transforms,
@ -784,17 +1011,6 @@ class Albu(object):
keymap=None, keymap=None,
update_pad_shape=False, update_pad_shape=False,
skip_img_without_anno=False): skip_img_without_anno=False):
"""
Adds custom transformations from Albumentations lib.
Please, visit `https://albumentations.readthedocs.io`
to get more information.
transforms (list): list of albu transformations
bbox_params (dict): bbox_params for albumentation `Compose`
keymap (dict): contains {'input key':'albumentation-style key'}
skip_img_without_anno (bool): whether to skip the image
if no ann left after aug
"""
if Compose is None: if Compose is None:
raise RuntimeError('albumentations is not installed') raise RuntimeError('albumentations is not installed')
@ -835,6 +1051,7 @@ class Albu(object):
Returns: Returns:
obj: The constructed object. obj: The constructed object.
""" """
assert isinstance(cfg, dict) and 'type' in cfg assert isinstance(cfg, dict) and 'type' in cfg
args = cfg.copy() args = cfg.copy()
@ -869,6 +1086,7 @@ class Albu(object):
Returns: Returns:
dict: new dict. dict: new dict.
""" """
updated_dict = {} updated_dict = {}
for k, v in zip(d.keys(), d.values()): for k, v in zip(d.keys(), d.values()):
new_k = keymap.get(k, k) new_k = keymap.get(k, k)
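The key-renaming helper at the end of the `Albu` diff can be illustrated with the following minimal sketch. The name `map_keys`, the assignment into the new dict, and the sample `keymap` are assumptions for this example; only the `keymap.get(k, k)` lookup is taken from the lines shown above.

```python
def map_keys(d, keymap):
    """Return a new dict whose keys are renamed according to ``keymap``."""
    updated = {}
    for k, v in d.items():
        # Keys missing from ``keymap`` are kept unchanged.
        new_k = keymap.get(k, k)
        updated[new_k] = v
    return updated


# Hypothetical mapping from pipeline keys to albumentation-style keys.
keymap = {'img': 'image', 'gt_bboxes': 'bboxes'}
original = {'img': 'fake_image', 'gt_bboxes': [[0, 0, 10, 10]], 'gt_labels': [1]}
mapped = map_keys(original, keymap)
# Inverting the mapping restores the original key names.
restored = map_keys(mapped, {v: k for k, v in keymap.items()})
```

Applying the inverted mapping afterwards is one way to translate the augmented result back to the pipeline's own key names.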

@ -27,6 +27,28 @@ class VOCDataset(XMLDataset):
proposal_nums=(100, 300, 1000), proposal_nums=(100, 300, 1000),
iou_thr=0.5, iou_thr=0.5,
scale_ranges=None): scale_ranges=None):
"""Evaluate in VOC protocol.
Args:
results (list[list | tuple]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. Options are
'mAP', 'recall'.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Default: None.
proposal_nums (Sequence[int]): Proposal number used for evaluating
recalls, such as recall@100, recall@1000.
Default: (100, 300, 1000).
iou_thr (float | list[float]): IoU threshold. It must be a float
when evaluating mAP, and can be a list when evaluating recall.
Default: 0.5.
scale_ranges (list[tuple], optional): Scale ranges for evaluating
mAP. If not specified, all bounding boxes would be included in
evaluation. Default: None.
Returns:
dict[str, float]: AP/recall metrics.
"""
if not isinstance(metric, str): if not isinstance(metric, str):
assert len(metric) == 1 assert len(metric) == 1
metric = metric[0] metric = metric[0]

@ -20,6 +20,15 @@ class WIDERFaceDataset(XMLDataset):
super(WIDERFaceDataset, self).__init__(**kwargs) super(WIDERFaceDataset, self).__init__(**kwargs)
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
"""Load annotation from WIDERFace XML style annotation file.
Args:
ann_file (str): Path of XML file.
Returns:
list[dict]: Annotation info from XML file.
"""
data_infos = [] data_infos = []
img_ids = mmcv.list_from_file(ann_file) img_ids = mmcv.list_from_file(ann_file)
for img_id in img_ids: for img_id in img_ids:

@ -11,6 +11,13 @@ from .custom import CustomDataset
@DATASETS.register_module() @DATASETS.register_module()
class XMLDataset(CustomDataset): class XMLDataset(CustomDataset):
"""XML dataset for detection.
Args:
min_size (int | float, optional): The minimum size of bounding
boxes in the images. If the size of a bounding box is less than
``min_size``, it will be added to the ignored field.
"""
def __init__(self, min_size=None, **kwargs): def __init__(self, min_size=None, **kwargs):
super(XMLDataset, self).__init__(**kwargs) super(XMLDataset, self).__init__(**kwargs)
@ -18,6 +25,15 @@ class XMLDataset(CustomDataset):
self.min_size = min_size self.min_size = min_size
def load_annotations(self, ann_file): def load_annotations(self, ann_file):
"""Load annotation from XML style ann_file.
Args:
ann_file (str): Path of XML file.
Returns:
list[dict]: Annotation info from XML file.
"""
data_infos = [] data_infos = []
img_ids = mmcv.list_from_file(ann_file) img_ids = mmcv.list_from_file(ann_file)
for img_id in img_ids: for img_id in img_ids:
@ -43,8 +59,7 @@ class XMLDataset(CustomDataset):
return data_infos return data_infos
def get_subset_by_classes(self): def get_subset_by_classes(self):
"""Filter imgs by user-defined categories """Filter imgs by user-defined categories."""
"""
subset_data_infos = [] subset_data_infos = []
for data_info in self.data_infos: for data_info in self.data_infos:
img_id = data_info['id'] img_id = data_info['id']
@ -61,6 +76,15 @@ class XMLDataset(CustomDataset):
return subset_data_infos return subset_data_infos
def get_ann_info(self, idx): def get_ann_info(self, idx):
"""Get annotation from XML file by index.
Args:
idx (int): Index of data.
Returns:
dict: Annotation info of specified index.
"""
img_id = self.data_infos[idx]['id'] img_id = self.data_infos[idx]['id']
xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml') xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml')
tree = ET.parse(xml_path) tree = ET.parse(xml_path)
@ -117,6 +141,15 @@ class XMLDataset(CustomDataset):
return ann return ann
def get_cat_ids(self, idx): def get_cat_ids(self, idx):
"""Get category ids in XML file by index.
Args:
idx (int): Index of data.
Returns:
list[int]: All categories in the image of specified index.
"""
cat_ids = [] cat_ids = []
img_id = self.data_infos[idx]['id'] img_id = self.data_infos[idx]['id']
xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml') xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml')
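As a hedged sketch of what collecting category ids from an XML annotation might look like, the example below parses an in-memory string with the standard-library `xml.etree.ElementTree`. The `CLASSES` tuple, the `cat2label` mapping, the helper name, and the `<object><name>` layout are assumptions chosen for illustration; the actual method reads a file under `img_prefix/Annotations`.

```python
import xml.etree.ElementTree as ET

# Hypothetical class list and name-to-id mapping.
CLASSES = ('cat', 'dog')
cat2label = {name: i for i, name in enumerate(CLASSES)}

# Hypothetical VOC-style annotation with three objects.
XML_TEXT = """
<annotation>
  <object><name>dog</name></object>
  <object><name>cat</name></object>
  <object><name>bird</name></object>
</annotation>
"""


def get_cat_ids_from_xml(xml_text):
    """Collect the category ids of all known objects in one annotation."""
    root = ET.fromstring(xml_text)
    cat_ids = []
    for obj in root.findall('object'):
        name = obj.find('name').text
        # Objects outside the configured classes are skipped.
        if name not in cat2label:
            continue
        cat_ids.append(cat2label[name])
    return cat_ids


cat_ids = get_cat_ids_from_xml(XML_TEXT)
```

Categories that are not in the configured class list are dropped, so the returned ids always index into `CLASSES`.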

@ -14,8 +14,8 @@ class GenericRoIExtractor(BaseRoIExtractor):
Args: Args:
aggregation (str): The method to aggregate multiple feature maps. aggregation (str): The method to aggregate multiple feature maps.
Options are 'sum', 'concat'. Default: 'sum'. Options are 'sum', 'concat'. Default: 'sum'.
pre_cfg (dict|None): Specify pre-processing modules. Default: None. pre_cfg (dict | None): Specify pre-processing modules. Default: None.
post_cfg (dict|None): Specify post-processing modules. Default: None. post_cfg (dict | None): Specify post-processing modules. Default: None.
kwargs (keyword arguments): Arguments that are the same kwargs (keyword arguments): Arguments that are the same
as :class:`BaseRoIExtractor`. as :class:`BaseRoIExtractor`.
""" """
