|
|
|
@ -2802,6 +2802,7 @@ class CopyPaste: |
|
|
|
|
self.bbox_occluded_thr = bbox_occluded_thr |
|
|
|
|
self.mask_occluded_thr = mask_occluded_thr |
|
|
|
|
self.selected = selected |
|
|
|
|
self.paste_by_box = False |
|
|
|
|
|
|
|
|
|
def get_indexes(self, dataset): |
|
|
|
|
"""Call function to collect indexes.s. |
|
|
|
@ -2813,6 +2814,42 @@ class CopyPaste: |
|
|
|
|
""" |
|
|
|
|
return random.randint(0, len(dataset)) |
|
|
|
|
|
|
|
|
|
def gen_masks_from_bboxes(self, bboxes, img_shape): |
|
|
|
|
"""Generate gt_masks based on gt_bboxes. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
bboxes (list): The bboxes's list. |
|
|
|
|
img_shape (tuple): The shape of image. |
|
|
|
|
Returns: |
|
|
|
|
BitmapMasks |
|
|
|
|
""" |
|
|
|
|
self.paste_by_box = True |
|
|
|
|
img_h, img_w = img_shape[:2] |
|
|
|
|
xmin, ymin = bboxes[:, 0:1], bboxes[:, 1:2] |
|
|
|
|
xmax, ymax = bboxes[:, 2:3], bboxes[:, 3:4] |
|
|
|
|
gt_masks = np.zeros((len(bboxes), img_h, img_w), dtype=np.uint8) |
|
|
|
|
for i in range(len(bboxes)): |
|
|
|
|
gt_masks[i, |
|
|
|
|
int(ymin[i]):int(ymax[i]), |
|
|
|
|
int(xmin[i]):int(xmax[i])] = 1 |
|
|
|
|
return BitmapMasks(gt_masks, img_h, img_w) |
|
|
|
|
|
|
|
|
|
def get_gt_masks(self, results): |
|
|
|
|
"""Get gt_masks originally or generated based on bboxes. |
|
|
|
|
|
|
|
|
|
If gt_masks is not contained in results, |
|
|
|
|
it will be generated based on gt_bboxes. |
|
|
|
|
Args: |
|
|
|
|
results (dict): Result dict. |
|
|
|
|
Returns: |
|
|
|
|
BitmapMasks: gt_masks, originally or generated based on bboxes. |
|
|
|
|
""" |
|
|
|
|
if results.get('gt_masks', None) is not None: |
|
|
|
|
return results['gt_masks'] |
|
|
|
|
else: |
|
|
|
|
return self.gen_masks_from_bboxes( |
|
|
|
|
results.get('gt_bboxes', []), results['img'].shape) |
|
|
|
|
|
|
|
|
|
def __call__(self, results): |
|
|
|
|
"""Call function to make a copy-paste of image. |
|
|
|
|
|
|
|
|
@ -2826,6 +2863,13 @@ class CopyPaste: |
|
|
|
|
num_images = len(results['mix_results']) |
|
|
|
|
assert num_images == 1, \ |
|
|
|
|
f'CopyPaste only supports processing 2 images, got {num_images}' |
|
|
|
|
|
|
|
|
|
# Get gt_masks originally or generated based on bboxes. |
|
|
|
|
results['gt_masks'] = self.get_gt_masks(results) |
|
|
|
|
# only one mix picture |
|
|
|
|
results['mix_results'][0]['gt_masks'] = self.get_gt_masks( |
|
|
|
|
results['mix_results'][0]) |
|
|
|
|
|
|
|
|
|
if self.selected: |
|
|
|
|
selected_results = self._select_object(results['mix_results'][0]) |
|
|
|
|
else: |
|
|
|
@ -2871,6 +2915,8 @@ class CopyPaste: |
|
|
|
|
src_masks = src_results['gt_masks'] |
|
|
|
|
|
|
|
|
|
if len(src_bboxes) == 0: |
|
|
|
|
if self.paste_by_box: |
|
|
|
|
dst_results.pop('gt_masks') |
|
|
|
|
return dst_results |
|
|
|
|
|
|
|
|
|
# update masks and generate bboxes from updated masks |
|
|
|
@ -2899,6 +2945,9 @@ class CopyPaste: |
|
|
|
|
dst_results['img'] = img |
|
|
|
|
dst_results['gt_bboxes'] = bboxes |
|
|
|
|
dst_results['gt_labels'] = labels |
|
|
|
|
if self.paste_by_box: |
|
|
|
|
dst_results.pop('gt_masks') |
|
|
|
|
else: |
|
|
|
|
dst_results['gt_masks'] = BitmapMasks(masks, masks.shape[1], |
|
|
|
|
masks.shape[2]) |
|
|
|
|
|
|
|
|
|