[Feats]: support copy paste based on bbox when there is no gt mask (#8905)

* copypaste_based_on_bbox

* docformatter

* formatter again

* fix bug of dict get method in check_gt_masks

* test copypaste based on bbox

* change mask generating method

* docformatter

* docformatter

* update comment 'result' to 'results'

* rename

* Update transforms.py

* Update transforms.py

* Update transforms.py

* rename 'mask_gen' to 'using_cutmix'

* yapf

* Add files via upload

* yapf

* paste_by_box

drop parameter "using_cutmix"
add indicator "paste_by_box"
pull/9185/head
JarvisKevin 3 years ago committed by GitHub
parent a12be714e9
commit 1f8b195f72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 49
      mmdet/datasets/pipelines/transforms.py
  2. 17
      tests/test_data/test_pipelines/test_transform/test_transform.py

@ -2802,6 +2802,7 @@ class CopyPaste:
self.bbox_occluded_thr = bbox_occluded_thr
self.mask_occluded_thr = mask_occluded_thr
self.selected = selected
self.paste_by_box = False
def get_indexes(self, dataset):
"""Call function to collect indexes.
@ -2813,6 +2814,42 @@ class CopyPaste:
"""
return random.randint(0, len(dataset))
def gen_masks_from_bboxes(self, bboxes, img_shape):
    """Generate rectangular gt_masks from gt_bboxes.

    Every box is rasterised as a filled rectangle on its own
    ``img_h x img_w`` canvas. Calling this method also sets
    ``self.paste_by_box`` to ``True``; other methods of this class read
    that flag to remove the synthetic masks from the output again.

    NOTE(review): nothing visible in this method resets the flag, so it
    stays ``True`` for later samples - confirm this is intended.

    Args:
        bboxes (np.ndarray | list): Boxes given as rows of
            ``(xmin, ymin, xmax, ymax)``. Anything ``np.asarray`` can
            turn into an ``(n, 4)`` array is accepted, including an
            empty list.
        img_shape (tuple): The shape of image; only the first two
            entries (height, width) are used.

    Returns:
        BitmapMasks: One binary mask per box.
    """
    self.paste_by_box = True
    img_h, img_w = img_shape[:2]

    # Normalise the input first: a plain list (e.g. the empty default
    # used when 'gt_bboxes' is missing) cannot be indexed with [:, k].
    boxes = np.asarray(bboxes).reshape(-1, 4)

    # Clip to the canvas before converting to integers. A negative
    # coordinate would otherwise be read by the slice below as an offset
    # from the far edge and fill the wrong region (or nothing at all).
    x_spans = np.clip(boxes[:, 0::2], 0, img_w).astype(np.int64)
    y_spans = np.clip(boxes[:, 1::2], 0, img_h).astype(np.int64)

    gt_masks = np.zeros((len(boxes), img_h, img_w), dtype=np.uint8)
    # Iterating a 3-D array yields views, so the assignment writes
    # straight into ``gt_masks``.
    for mask, (x1, x2), (y1, y2) in zip(gt_masks, x_spans, y_spans):
        mask[y1:y2, x1:x2] = 1
    return BitmapMasks(gt_masks, img_h, img_w)
def get_gt_masks(self, results):
    """Get gt_masks originally or generated based on bboxes.

    If ``results`` holds a non-``None`` ``'gt_masks'`` entry it is
    returned untouched. Otherwise the masks are generated from
    ``results['gt_bboxes']`` via ``self.gen_masks_from_bboxes`` using the
    shape of ``results['img']``.

    Args:
        results (dict): Result dict. Must contain ``'img'`` whenever
            ``'gt_masks'`` is missing or ``None``.

    Returns:
        BitmapMasks: gt_masks, originally or generated based on bboxes.
    """
    gt_masks = results.get('gt_masks')
    if gt_masks is not None:
        return gt_masks

    bboxes = results.get('gt_bboxes')
    if bboxes is None:
        # Fall back to an empty (0, 4) array rather than a bare list:
        # the mask generator slices its input column-wise, which a
        # Python list does not support.
        bboxes = np.zeros((0, 4), dtype=np.float32)
    return self.gen_masks_from_bboxes(bboxes, results['img'].shape)
def __call__(self, results):
"""Call function to make a copy-paste of image.
@ -2826,6 +2863,13 @@ class CopyPaste:
num_images = len(results['mix_results'])
assert num_images == 1, \
f'CopyPaste only supports processing 2 images, got {num_images}'
# Get gt_masks originally or generated based on bboxes.
results['gt_masks'] = self.get_gt_masks(results)
# only one mix picture
results['mix_results'][0]['gt_masks'] = self.get_gt_masks(
results['mix_results'][0])
if self.selected:
selected_results = self._select_object(results['mix_results'][0])
else:
@ -2871,6 +2915,8 @@ class CopyPaste:
src_masks = src_results['gt_masks']
if len(src_bboxes) == 0:
if self.paste_by_box:
dst_results.pop('gt_masks')
return dst_results
# update masks and generate bboxes from updated masks
@ -2899,6 +2945,9 @@ class CopyPaste:
dst_results['img'] = img
dst_results['gt_bboxes'] = bboxes
dst_results['gt_labels'] = labels
if self.paste_by_box:
dst_results.pop('gt_masks')
else:
dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1],
masks.shape[2])

@ -1099,3 +1099,20 @@ def test_copypaste():
src_results['gt_masks'] = src_masks[valid_inds]
results['mix_results'] = [copy.deepcopy(src_results)]
copypaste_module(results)
# test copy_paste based on bbox
dst_results.pop('gt_masks')
src_results.pop('gt_masks')
dst_bboxes = dst_results['gt_bboxes']
src_bboxes = src_results['gt_bboxes']
dst_masks = create_full_masks(dst_bboxes, w, h)
src_masks = create_full_masks(src_bboxes, w, h)
results = copy.deepcopy(dst_results)
results['mix_results'] = [copy.deepcopy(src_results)]
results = copypaste_module(results)
result_masks = create_full_masks(results['gt_bboxes'], w, h)
result_masks_np = np.where(result_masks.to_ndarray().sum(0) > 0, 1, 0)
masks_np = np.where(
(src_masks.to_ndarray().sum(0) + dst_masks.to_ndarray().sum(0)) > 0, 1,
0)
assert np.all(result_masks_np == masks_np)
assert 'gt_masks' not in results

Loading…
Cancel
Save