[Enhance]: Check dtype in transform unit tests. (#5969)

* add dtype to bbox unit test

* add dtype to unit test

* add docstring

* add label and check dtype
pull/5946/head^2
RangiLyu 3 years ago committed by GitHub
parent 13880ae5fd
commit f05f843da9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      tests/test_data/test_pipelines/test_transform/__init__.py
  2. 32
      tests/test_data/test_pipelines/test_transform/test_img_augment.py
  3. 69
      tests/test_data/test_pipelines/test_transform/test_rotate.py
  4. 78
      tests/test_data/test_pipelines/test_transform/test_shear.py
  5. 111
      tests/test_data/test_pipelines/test_transform/test_transform.py
  6. 78
      tests/test_data/test_pipelines/test_transform/utils.py

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import check_result_same, construct_toy_data, create_random_bboxes
__all__ = ['create_random_bboxes', 'construct_toy_data', 'check_result_same']

@ -6,38 +6,8 @@ import numpy as np
from mmcv.utils import build_from_cfg
from numpy.testing import assert_array_equal
from mmdet.core.mask import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES
def construct_toy_data(poly2mask=True):
img = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8)
img = np.stack([img, img, img], axis=-1)
results = dict()
# image
results['img'] = img
results['img_shape'] = img.shape
results['img_fields'] = ['img']
# bboxes
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
results['gt_bboxes'] = np.array([[0., 0., 2., 1.]], dtype=np.float32)
results['gt_bboxes_ignore'] = np.array([[2., 0., 3., 1.]],
dtype=np.float32)
# labels
results['gt_labels'] = np.array([1], dtype=np.int64)
# masks
results['mask_fields'] = ['gt_masks']
if poly2mask:
gt_masks = np.array([[0, 1, 1, 0], [0, 1, 0, 0]],
dtype=np.uint8)[None, :, :]
results['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
else:
raw_masks = [[np.array([1, 0, 2, 0, 2, 1, 1, 1], dtype=np.float)]]
results['gt_masks'] = PolygonMasks(raw_masks, 2, 4)
# segmentations
results['seg_fields'] = ['gt_semantic_seg']
results['gt_semantic_seg'] = img[..., 0]
return results
from .utils import construct_toy_data
def test_adjust_color():

@ -7,60 +7,7 @@ from mmcv.utils import build_from_cfg
from mmdet.core.mask import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES
def construct_toy_data(poly2mask=True):
img = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8)
img = np.stack([img, img, img], axis=-1)
results = dict()
# image
results['img'] = img
results['img_shape'] = img.shape
results['img_fields'] = ['img']
# bboxes
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
results['gt_bboxes'] = np.array([[0., 0., 2., 1.]], dtype=np.float32)
results['gt_bboxes_ignore'] = np.array([[2., 0., 3., 1.]],
dtype=np.float32)
# labels
results['gt_labels'] = np.array([1], dtype=np.int64)
# masks
results['mask_fields'] = ['gt_masks']
if poly2mask:
gt_masks = np.array([[0, 1, 1, 0], [0, 1, 0, 0]],
dtype=np.uint8)[None, :, :]
results['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
else:
raw_masks = [[np.array([0, 0, 2, 0, 2, 1, 0, 1], dtype=np.float)]]
results['gt_masks'] = PolygonMasks(raw_masks, 2, 4)
# segmentations
results['seg_fields'] = ['gt_semantic_seg']
results['gt_semantic_seg'] = img[..., 0]
return results
def _check_fields(results, results_rotated, keys):
for key in keys:
if isinstance(results[key], (BitmapMasks, PolygonMasks)):
assert np.equal(results[key].to_ndarray(),
results_rotated[key].to_ndarray()).all()
else:
assert np.equal(results[key], results_rotated[key]).all()
def check_rotate(results, results_rotated):
# check image
_check_fields(results, results_rotated, results.get('img_fields', ['img']))
# check bboxes
_check_fields(results, results_rotated, results.get('bbox_fields', []))
# check masks
_check_fields(results, results_rotated, results.get('mask_fields', []))
# check segmentations
_check_fields(results, results_rotated, results.get('seg_fields', []))
# _check gt_labels
if 'gt_labels' in results:
assert np.equal(results['gt_labels'],
results_rotated['gt_labels']).all()
from .utils import check_result_same, construct_toy_data
def test_rotate():
@ -105,14 +52,14 @@ def test_rotate():
)
rotate_module = build_from_cfg(transform, PIPELINES)
results_wo_rotate = rotate_module(copy.deepcopy(results))
check_rotate(results, results_wo_rotate)
check_result_same(results, results_wo_rotate)
# test case when no rotate aug (prob<=0)
transform = dict(
type='Rotate', level=10, prob=0., img_fill_val=img_fill_val, scale=0.6)
rotate_module = build_from_cfg(transform, PIPELINES)
results_wo_rotate = rotate_module(copy.deepcopy(results))
check_rotate(results, results_wo_rotate)
check_result_same(results, results_wo_rotate)
# test clockwise rotation with angle 90
results = construct_toy_data()
@ -140,14 +87,14 @@ def test_rotate():
results_gt['gt_semantic_seg'] = np.array(
[[255, 6, 2, 255], [255, 7, 3,
255]]).astype(results['gt_semantic_seg'].dtype)
check_rotate(results_gt, results_rotated)
check_result_same(results_gt, results_rotated)
# test clockwise rotation with angle 90, PolygonMasks
results = construct_toy_data(poly2mask=False)
results_rotated = rotate_module(copy.deepcopy(results))
gt_masks = [[np.array([2, 0, 2, 1, 1, 1, 1, 0], dtype=np.float)]]
results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
check_rotate(results_gt, results_rotated)
check_result_same(results_gt, results_rotated)
# test counter-clockwise roatation with angle 90,
# and specify the ratation center
@ -183,7 +130,7 @@ def test_rotate():
gt_seg = (np.ones((h, w)) * 255).astype(results['gt_semantic_seg'].dtype)
gt_seg[0, 0], gt_seg[0, 1] = 1, 5
results_gt['gt_semantic_seg'] = gt_seg
check_rotate(results_gt, results_rotated)
check_result_same(results_gt, results_rotated)
transform = dict(
type='Rotate',
@ -195,7 +142,7 @@ def test_rotate():
prob=1.)
rotate_module = build_from_cfg(transform, PIPELINES)
results_rotated = rotate_module(copy.deepcopy(results))
check_rotate(results_gt, results_rotated)
check_result_same(results_gt, results_rotated)
# test counter-clockwise roatation with angle 90,
# and specify the ratation center, PolygonMasks
@ -203,7 +150,7 @@ def test_rotate():
results_rotated = rotate_module(copy.deepcopy(results))
gt_masks = [[np.array([0, 0, 0, 0, 1, 0, 1, 0], dtype=np.float)]]
results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
check_rotate(results_gt, results_rotated)
check_result_same(results_gt, results_rotated)
# test AutoAugment equipped with Rotate
policies = [[dict(type='Rotate', level=10, prob=1.)]]

@ -7,62 +7,7 @@ from mmcv.utils import build_from_cfg
from mmdet.core.mask import BitmapMasks, PolygonMasks
from mmdet.datasets.builder import PIPELINES
def construct_toy_data(poly2mask=True):
img = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8)
img = np.stack([img, img, img], axis=-1)
results = dict()
# image
results['img'] = img
results['img_shape'] = img.shape
results['img_fields'] = ['img']
# bboxes
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
results['gt_bboxes'] = np.array([[0., 0., 2., 1.]], dtype=np.float32)
results['gt_bboxes_ignore'] = np.array([[2., 0., 3., 1.]],
dtype=np.float32)
# labels
results['gt_labels'] = np.array([1], dtype=np.int64)
# masks
results['mask_fields'] = ['gt_masks']
if poly2mask:
gt_masks = np.array([[0, 1, 1, 0], [0, 1, 0, 0]],
dtype=np.uint8)[None, :, :]
results['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
else:
raw_masks = [[np.array([1, 0, 2, 0, 2, 1, 1, 1], dtype=np.float)]]
results['gt_masks'] = PolygonMasks(raw_masks, 2, 4)
# segmentations
results['seg_fields'] = ['gt_semantic_seg']
results['gt_semantic_seg'] = img[..., 0]
return results
def _check_fields(results, results_sheared, keys):
for key in keys:
if isinstance(results[key], (BitmapMasks, PolygonMasks)):
assert np.equal(results[key].to_ndarray(),
results_sheared[key].to_ndarray()).all()
else:
assert np.equal(results[key], results_sheared[key]).all()
def check_shear(results, results_sheared):
# _check_keys(results, results_sheared)
# check image
_check_fields(results, results_sheared, results.get('img_fields', ['img']))
# check bboxes
_check_fields(results, results_sheared, results.get('bbox_fields', []))
# check masks
_check_fields(results, results_sheared, results.get('mask_fields', []))
# check segmentations
_check_fields(results, results_sheared, results.get('seg_fields', []))
# check gt_labels
if 'gt_labels' in results:
assert np.equal(results['gt_labels'],
results_sheared['gt_labels']).all()
from .utils import check_result_same, construct_toy_data
def test_shear():
@ -94,7 +39,7 @@ def test_shear():
direction='horizontal')
shear_module = build_from_cfg(transform, PIPELINES)
results_wo_shear = shear_module(copy.deepcopy(results))
check_shear(results, results_wo_shear)
check_result_same(results, results_wo_shear)
# test case when no shear aug (level=0, direction='vertical')
transform = dict(
@ -106,7 +51,7 @@ def test_shear():
direction='vertical')
shear_module = build_from_cfg(transform, PIPELINES)
results_wo_shear = shear_module(copy.deepcopy(results))
check_shear(results, results_wo_shear)
check_result_same(results, results_wo_shear)
# test case when no shear aug (prob<=0)
transform = dict(
@ -117,7 +62,7 @@ def test_shear():
direction='vertical')
shear_module = build_from_cfg(transform, PIPELINES)
results_wo_shear = shear_module(copy.deepcopy(results))
check_shear(results, results_wo_shear)
check_result_same(results, results_wo_shear)
# test shear horizontally, magnitude=1
transform = dict(
@ -143,14 +88,15 @@ def test_shear():
results_gt['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
results_gt['gt_semantic_seg'] = np.array(
[[1, 2, 3, 4], [255, 5, 6, 7]], dtype=results['gt_semantic_seg'].dtype)
check_shear(results_gt, results_sheared)
check_result_same(results_gt, results_sheared)
# test PolygonMasks with shear horizontally, magnitude=1
results = construct_toy_data(poly2mask=False)
results_sheared = shear_module(copy.deepcopy(results))
gt_masks = [[np.array([1, 0, 2, 0, 3, 1, 2, 1], dtype=np.float)]]
print(results_sheared['gt_masks'])
gt_masks = [[np.array([0, 0, 2, 0, 3, 1, 1, 1], dtype=np.float)]]
results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
check_shear(results_gt, results_sheared)
check_result_same(results_gt, results_sheared)
# test shear vertically, magnitude=-1
img_fill_val = 128
@ -180,14 +126,14 @@ def test_shear():
results_gt['gt_semantic_seg'] = np.array(
[[1, 6, 255, 255], [5, 255, 255, 255]],
dtype=results['gt_semantic_seg'].dtype)
check_shear(results_gt, results_sheared)
check_result_same(results_gt, results_sheared)
# test PolygonMasks with shear vertically, magnitude=-1
results = construct_toy_data(poly2mask=False)
results_sheared = shear_module(copy.deepcopy(results))
gt_masks = [[np.array([1, 0, 2, 0, 2, 0, 1, 0], dtype=np.float)]]
gt_masks = [[np.array([0, 0, 2, 0, 2, 0, 0, 1], dtype=np.float)]]
results_gt['gt_masks'] = PolygonMasks(gt_masks, 2, 4)
check_shear(results_gt, results_sheared)
check_result_same(results_gt, results_sheared)
results = construct_toy_data()
# same mask for BitmapMasks and PolygonMasks
@ -196,7 +142,7 @@ def test_shear():
4)
results['gt_bboxes'] = np.array([[1., 0., 2., 1.]], dtype=np.float32)
results_sheared_bitmap = shear_module(copy.deepcopy(results))
check_shear(results_sheared_bitmap, results_sheared)
check_result_same(results_sheared_bitmap, results_sheared)
# test AutoAugment equipped with Shear
policies = [[dict(type='Shear', level=10, prob=1.)]]

@ -10,6 +10,7 @@ from mmcv.utils import build_from_cfg
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from mmdet.datasets.builder import PIPELINES
from .utils import create_random_bboxes
def test_resize():
@ -78,6 +79,7 @@ def test_resize():
results = resize_module(results)
assert np.equal(results['img'], results['img2']).all()
assert results['img_shape'] == (800, 1280, 3)
assert results['img'].dtype == results['img'].dtype == np.uint8
def test_flip():
@ -200,17 +202,10 @@ def test_random_crop():
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.int)
return bboxes
h, w, _ = img.shape
gt_bboxes = create_random_bboxes(8, w, h)
gt_bboxes_ignore = create_random_bboxes(2, w, h)
results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64)
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_ignore'] = gt_bboxes_ignore
transform = dict(type='RandomCrop', crop_size=(h - 20, w - 20))
@ -219,6 +214,9 @@ def test_random_crop():
assert results['img'].shape[:2] == (h - 20, w - 20)
# All bboxes should be reserved after crop
assert results['img_shape'][:2] == (h - 20, w - 20)
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]
assert results['gt_labels'].dtype == np.int64
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes'].shape[0] == 8
assert results['gt_bboxes_ignore'].shape[0] == 2
@ -227,6 +225,8 @@ def test_random_crop():
assert (area(results['gt_bboxes']) <= area(gt_bboxes)).all()
assert (area(results['gt_bboxes_ignore']) <= area(gt_bboxes_ignore)).all()
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32
# test assertion for invalid crop_type
with pytest.raises(ValueError):
@ -269,6 +269,8 @@ def test_random_crop():
h, w = results_transformed['img_shape'][:2]
assert int(2 * 0.3 + 0.5) <= h <= int(2 * 1 + 0.5)
assert int(4 * 0.7 + 0.5) <= w <= int(4 * 1 + 0.5)
assert results_transformed['gt_bboxes'].dtype == np.float32
assert results_transformed['gt_bboxes_ignore'].dtype == np.float32
# test crop_type "relative"
transform = dict(
@ -280,6 +282,8 @@ def test_random_crop():
results_transformed = transform_module(copy.deepcopy(results))
h, w = results_transformed['img_shape'][:2]
assert h == int(2 * 0.3 + 0.5) and w == int(4 * 0.7 + 0.5)
assert results_transformed['gt_bboxes'].dtype == np.float32
assert results_transformed['gt_bboxes_ignore'].dtype == np.float32
# test crop_type "absolute"
transform = dict(
@ -291,6 +295,8 @@ def test_random_crop():
results_transformed = transform_module(copy.deepcopy(results))
h, w = results_transformed['img_shape'][:2]
assert h == 1 and w == 2
assert results_transformed['gt_bboxes'].dtype == np.float32
assert results_transformed['gt_bboxes_ignore'].dtype == np.float32
# test crop_type "absolute_range"
transform = dict(
@ -302,18 +308,11 @@ def test_random_crop():
results_transformed = transform_module(copy.deepcopy(results))
h, w = results_transformed['img_shape'][:2]
assert 1 <= h <= 2 and 1 <= w <= 4
assert results_transformed['gt_bboxes'].dtype == np.float32
assert results_transformed['gt_bboxes_ignore'].dtype == np.float32
def test_min_iou_random_crop():
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.int)
return bboxes
results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../../../data/color.jpg'), 'color')
@ -328,6 +327,7 @@ def test_min_iou_random_crop():
h, w, _ = img.shape
gt_bboxes = create_random_bboxes(1, w, h)
gt_bboxes_ignore = create_random_bboxes(1, w, h)
results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64)
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_ignore'] = gt_bboxes_ignore
transform = dict(type='MinIoURandomCrop')
@ -340,6 +340,11 @@ def test_min_iou_random_crop():
with pytest.raises(AssertionError):
crop_module(results_test)
results = crop_module(results)
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]
assert results['gt_labels'].dtype == np.int64
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32
patch = np.array([0, 0, results['img_shape'][1], results['img_shape'][0]])
ious = bbox_overlaps(patch.reshape(-1, 4),
results['gt_bboxes']).reshape(-1)
@ -555,14 +560,6 @@ def test_random_center_crop_pad():
results = load(results)
test_results = copy.deepcopy(results)
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.int)
return bboxes
h, w, _ = results['img_shape']
gt_bboxes = create_random_bboxes(8, w, h)
gt_bboxes_ignore = create_random_bboxes(2, w, h)
@ -585,6 +582,8 @@ def test_random_center_crop_pad():
assert train_results['pad_shape'][:2] == (h - 20, w - 20)
assert train_results['gt_bboxes'].shape[0] == 8
assert train_results['gt_bboxes_ignore'].shape[0] == 2
assert train_results['gt_bboxes'].dtype == np.float32
assert train_results['gt_bboxes_ignore'].dtype == np.float32
test_transform = dict(
type='RandomCenterCropPad',
@ -782,18 +781,10 @@ def test_random_shift():
# TODO: add img_fields test
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.int)
return bboxes
h, w, _ = img.shape
gt_bboxes = create_random_bboxes(8, w, h)
gt_bboxes_ignore = create_random_bboxes(2, w, h)
results['gt_labels'] = torch.ones(gt_bboxes.shape[0])
results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64)
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_ignore'] = gt_bboxes_ignore
transform = dict(type='RandomShift', shift_ratio=1.0)
@ -802,6 +793,9 @@ def test_random_shift():
assert results['img'].shape[:2] == (h, w)
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]
assert results['gt_labels'].dtype == np.int64
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32
def test_random_affine():
@ -825,18 +819,10 @@ def test_random_affine():
results['img'] = img
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.int)
return bboxes
h, w, _ = img.shape
gt_bboxes = create_random_bboxes(8, w, h)
gt_bboxes_ignore = create_random_bboxes(2, w, h)
results['gt_labels'] = torch.ones(gt_bboxes.shape[0])
results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64)
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_ignore'] = gt_bboxes_ignore
transform = dict(type='RandomAffine')
@ -845,10 +831,13 @@ def test_random_affine():
assert results['img'].shape[:2] == (h, w)
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]
assert results['gt_labels'].dtype == np.int64
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32
# test filter bbox
gt_bboxes = np.array([[0, 0, 1, 1], [0, 0, 3, 100]])
results['gt_labels'] = torch.ones(gt_bboxes.shape[0])
gt_bboxes = np.array([[0, 0, 1, 1], [0, 0, 3, 100]], dtype=np.float32)
results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64)
results['gt_bboxes'] = gt_bboxes
transform = dict(
type='RandomAffine',
@ -865,6 +854,10 @@ def test_random_affine():
assert results['gt_bboxes'].shape[0] == 0
assert results['gt_labels'].shape[0] == 0
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]
assert results['gt_labels'].dtype == np.int64
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32
def test_mosaic():
@ -880,18 +873,10 @@ def test_mosaic():
# TODO: add img_fields test
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.int)
return bboxes
h, w, _ = img.shape
gt_bboxes = create_random_bboxes(8, w, h)
gt_bboxes_ignore = create_random_bboxes(2, w, h)
results['gt_labels'] = torch.ones(gt_bboxes.shape[0])
results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64)
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_ignore'] = gt_bboxes_ignore
transform = dict(type='Mosaic', img_scale=(10, 12))
@ -904,6 +889,10 @@ def test_mosaic():
results['mix_results'] = [copy.deepcopy(results)] * 3
results = mosaic_module(results)
assert results['img'].shape[:2] == (20, 24)
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]
assert results['gt_labels'].dtype == np.int64
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32
def test_mixup():
@ -919,18 +908,10 @@ def test_mixup():
# TODO: add img_fields test
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.int)
return bboxes
h, w, _ = img.shape
gt_bboxes = create_random_bboxes(8, w, h)
gt_bboxes_ignore = create_random_bboxes(2, w, h)
results['gt_labels'] = torch.ones(gt_bboxes.shape[0])
results['gt_labels'] = np.ones(gt_bboxes.shape[0], dtype=np.int64)
results['gt_bboxes'] = gt_bboxes
results['gt_bboxes_ignore'] = gt_bboxes_ignore
transform = dict(type='MixUp', img_scale=(10, 12))
@ -947,3 +928,7 @@ def test_mixup():
results['mix_results'] = [copy.deepcopy(results)]
results = mixup_module(results)
assert results['img'].shape[:2] == (288, 512)
assert results['gt_labels'].shape[0] == results['gt_bboxes'].shape[0]
assert results['gt_labels'].dtype == np.int64
assert results['gt_bboxes'].dtype == np.float32
assert results['gt_bboxes_ignore'].dtype == np.float32

@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmdet.core.mask import BitmapMasks, PolygonMasks
def _check_fields(results, pipeline_results, keys):
"""Check data in fields from two results are same."""
for key in keys:
if isinstance(results[key], (BitmapMasks, PolygonMasks)):
assert np.equal(results[key].to_ndarray(),
pipeline_results[key].to_ndarray()).all()
else:
assert np.equal(results[key], pipeline_results[key]).all()
assert results[key].dtype == pipeline_results[key].dtype
def check_result_same(results, pipeline_results):
"""Check whether the `pipeline_results` is the same with the predefined
`results`.
Args:
results (dict): Predefined results which should be the standard output
of the transform pipeline.
pipeline_results (dict): Results processed by the transform pipeline.
"""
# check image
_check_fields(results, pipeline_results,
results.get('img_fields', ['img']))
# check bboxes
_check_fields(results, pipeline_results, results.get('bbox_fields', []))
# check masks
_check_fields(results, pipeline_results, results.get('mask_fields', []))
# check segmentations
_check_fields(results, pipeline_results, results.get('seg_fields', []))
# check gt_labels
if 'gt_labels' in results:
assert np.equal(results['gt_labels'],
pipeline_results['gt_labels']).all()
def construct_toy_data(poly2mask=True):
img = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.uint8)
img = np.stack([img, img, img], axis=-1)
results = dict()
# image
results['img'] = img
results['img_shape'] = img.shape
results['img_fields'] = ['img']
# bboxes
results['bbox_fields'] = ['gt_bboxes', 'gt_bboxes_ignore']
results['gt_bboxes'] = np.array([[0., 0., 2., 1.]], dtype=np.float32)
results['gt_bboxes_ignore'] = np.array([[2., 0., 3., 1.]],
dtype=np.float32)
# labels
results['gt_labels'] = np.array([1], dtype=np.int64)
# masks
results['mask_fields'] = ['gt_masks']
if poly2mask:
gt_masks = np.array([[0, 1, 1, 0], [0, 1, 0, 0]],
dtype=np.uint8)[None, :, :]
results['gt_masks'] = BitmapMasks(gt_masks, 2, 4)
else:
raw_masks = [[np.array([0, 0, 2, 0, 2, 1, 0, 1], dtype=np.float)]]
results['gt_masks'] = PolygonMasks(raw_masks, 2, 4)
# segmentations
results['seg_fields'] = ['gt_semantic_seg']
results['gt_semantic_seg'] = img[..., 0]
return results
def create_random_bboxes(num_bboxes, img_w, img_h):
bboxes_left_top = np.random.uniform(0, 0.5, size=(num_bboxes, 2))
bboxes_right_bottom = np.random.uniform(0.5, 1, size=(num_bboxes, 2))
bboxes = np.concatenate((bboxes_left_top, bboxes_right_bottom), 1)
bboxes = (bboxes * np.array([img_w, img_h, img_w, img_h])).astype(
np.float32)
return bboxes
Loading…
Cancel
Save