diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
index 0a1b38911..c3c6a56f3 100644
--- a/mmdet/datasets/pipelines/transforms.py
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -2802,6 +2802,7 @@ class CopyPaste:
         self.bbox_occluded_thr = bbox_occluded_thr
         self.mask_occluded_thr = mask_occluded_thr
         self.selected = selected
+        self.paste_by_box = False
 
     def get_indexes(self, dataset):
         """Call function to collect indexes.s.
@@ -2813,6 +2814,48 @@ class CopyPaste:
         """
         return random.randint(0, len(dataset))
 
+    def gen_masks_from_bboxes(self, bboxes, img_shape):
+        """Generate box-shaped gt_masks from gt_bboxes.
+
+        Each box is rasterised as a filled rectangle. Calling this also
+        sets ``self.paste_by_box`` so that the generated masks are dropped
+        from the final results.
+
+        Args:
+            bboxes (np.ndarray | list): Boxes as (xmin, ymin, xmax, ymax)
+                rows. May be empty.
+            img_shape (tuple): The shape of image, (h, w, ...).
+
+        Returns:
+            BitmapMasks: One binary mask per box.
+        """
+        self.paste_by_box = True
+        img_h, img_w = img_shape[:2]
+        gt_masks = np.zeros((len(bboxes), img_h, img_w), dtype=np.uint8)
+        for i, bbox in enumerate(bboxes):
+            # Convert scalars one at a time: int() on a 1-element array is
+            # deprecated in recent NumPy releases.
+            xmin, ymin, xmax, ymax = (int(coord) for coord in bbox[:4])
+            gt_masks[i, ymin:ymax, xmin:xmax] = 1
+        return BitmapMasks(gt_masks, img_h, img_w)
+
+    def get_gt_masks(self, results):
+        """Get gt_masks originally or generated based on bboxes.
+
+        If gt_masks is not contained in results,
+        it will be generated based on gt_bboxes.
+
+        Args:
+            results (dict): Result dict.
+
+        Returns:
+            BitmapMasks: gt_masks, originally or generated based on bboxes.
+        """
+        if results.get('gt_masks', None) is not None:
+            return results['gt_masks']
+        return self.gen_masks_from_bboxes(
+            results.get('gt_bboxes', []), results['img'].shape)
+
     def __call__(self, results):
         """Call function to make a copy-paste of image.
 
@@ -2826,6 +2869,16 @@ class CopyPaste:
         num_images = len(results['mix_results'])
         assert num_images == 1, \
             f'CopyPaste only supports processing 2 images, got {num_images}'
+
+        # gen_masks_from_bboxes() raises this flag; clear it on every call so
+        # a box-only sample cannot make later samples lose their real masks.
+        self.paste_by_box = False
+        # Get gt_masks originally or generated based on bboxes.
+        results['gt_masks'] = self.get_gt_masks(results)
+        # only one mix picture
+        results['mix_results'][0]['gt_masks'] = self.get_gt_masks(
+            results['mix_results'][0])
+
         if self.selected:
             selected_results = self._select_object(results['mix_results'][0])
         else:
@@ -2871,6 +2924,8 @@ class CopyPaste:
         src_masks = src_results['gt_masks']
 
         if len(src_bboxes) == 0:
+            if self.paste_by_box:
+                dst_results.pop('gt_masks')
             return dst_results
 
         # update masks and generate bboxes from updated masks
@@ -2899,8 +2954,11 @@ class CopyPaste:
         dst_results['img'] = img
         dst_results['gt_bboxes'] = bboxes
         dst_results['gt_labels'] = labels
-        dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1],
-                                              masks.shape[2])
+        if self.paste_by_box:
+            dst_results.pop('gt_masks')
+        else:
+            dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1],
+                                                  masks.shape[2])
 
         return dst_results
 
diff --git a/tests/test_data/test_pipelines/test_transform/test_transform.py b/tests/test_data/test_pipelines/test_transform/test_transform.py
index 8bc1cbb31..1ebc4f369 100644
--- a/tests/test_data/test_pipelines/test_transform/test_transform.py
+++ b/tests/test_data/test_pipelines/test_transform/test_transform.py
@@ -1099,3 +1099,27 @@ def test_copypaste():
     src_results['gt_masks'] = src_masks[valid_inds]
     results['mix_results'] = [copy.deepcopy(src_results)]
     copypaste_module(results)
+    # test copy_paste based on bbox
+    dst_results.pop('gt_masks')
+    src_results.pop('gt_masks')
+    dst_bboxes = dst_results['gt_bboxes']
+    src_bboxes = src_results['gt_bboxes']
+    dst_masks = create_full_masks(dst_bboxes, w, h)
+    src_masks = create_full_masks(src_bboxes, w, h)
+    results = copy.deepcopy(dst_results)
+    results['mix_results'] = [copy.deepcopy(src_results)]
+    results = copypaste_module(results)
+    result_masks = create_full_masks(results['gt_bboxes'], w, h)
+    result_masks_np = np.where(result_masks.to_ndarray().sum(0) > 0, 1, 0)
+    masks_np = np.where(
+        (src_masks.to_ndarray().sum(0) + dst_masks.to_ndarray().sum(0)) > 0, 1,
+        0)
+    assert np.all(result_masks_np == masks_np)
+    assert 'gt_masks' not in results
+    # a bbox-only call must not make later mask-based calls drop gt_masks
+    dst_results['gt_masks'] = dst_masks
+    src_results['gt_masks'] = src_masks
+    results = copy.deepcopy(dst_results)
+    results['mix_results'] = [copy.deepcopy(src_results)]
+    results = copypaste_module(results)
+    assert 'gt_masks' in results