diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index 2e0c7dc0e2..b9bddc7596 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -200,11 +200,10 @@ class Mosaic(BaseMixTransform): def get_indexes(self, buffer=True): """Return a list of random indexes from the dataset.""" - return np.random.choice(len(self.dataset), self.n - 1, replace=False, p=self.dataset.cls_weights) - # if buffer: # select images from buffer - # return random.choices(list(self.dataset.buffer), k=self.n - 1) - # else: # select any images - # return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)] + if buffer: # select images from buffer + return random.choices(list(self.dataset.buffer), k=self.n - 1) + else: # select any images + return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)] def _mix_transform(self, labels): """Apply mixup transformation to the input image and labels.""" @@ -795,7 +794,106 @@ class LetterBox: return labels -class CopyPaste: +class CopyPaste(BaseMixTransform): + """ + Implements Copy-Paste augmentation as described in https://arxiv.org/abs/2012.07177. + + This class applies Copy-Paste augmentation on images and their corresponding instances. + + Attributes: + dataset: The dataset on which the copypaste augmentation is applied. + pre_transform: The pre-transforms for the mixed labels. + p (float): Probability of applying the Copy-Paste augmentation. Must be between 0 and 1. + + Methods: + __call__: Applies Copy-Paste augmentation to given image and instances. + + Examples: + >>> copypaste = CopyPaste(dataset, p=0.5) + >>> augmented_labels = copypaste(labels) + >>> augmented_image = augmented_labels['img'] + """ + + def __init__(self, dataset, pre_transform=None, p=0.5) -> None: + """Initializes CopyPaste object with dataset, pre_transform, and probability of applying MixUp.""" + super().__init__(dataset=dataset, pre_transform=pre_transform, p=p) + + def get_indexes(self): + """ + Get a random index from the dataset. + + This method returns a single random index from the dataset, which is used to select an image for MixUp + augmentation. + + Returns: + (int): A random integer index within the range of the dataset length. + + Examples: + >>> copypaste = CopyPaste(dataset) + >>> index = copypaste.get_indexes() + >>> print(index) + 42 + """ + return random.randint(0, len(self.dataset) - 1) + + def _mix_transform(self, labels): + """Applies CopyPaste augmentation.""" + labels2 = labels["mix_labels"][0] + im = labels["img"] + cls = labels["cls"] + h, w = im.shape[:2] + instances = labels.pop("instances") + instances.convert_bbox(format="xyxy") + instances.denormalize(w, h) + + im_new = np.zeros(im.shape, np.uint8) + instances2 = labels2.pop("instances") + ioa = bbox_ioa(instances2.bboxes, instances.bboxes) # intersection over area, (N, M) + indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, ) + n = len(indexes) + # for j in random.sample(list(indexes), k=round(self.p * n)): + sorted_idx = np.argsort(ioa.max(1)[indexes]) + indexes = indexes[sorted_idx] + for j in indexes[: round(self.p * n)]: + cls = np.concatenate((cls, labels2["cls"][[j]]), axis=0) + instances = Instances.concatenate((instances, instances2[[j]]), axis=0) + cv2.drawContours(im_new, instances2.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED) + + result = labels2["img"] # augment segments + i = im_new.astype(bool) + im[i] = result[i] + + labels["img"] = im + labels["cls"] = cls + labels["instances"] = instances + return labels + + def __call__(self, labels): + """Applies pre-processing transforms and copy_paste transforms to labels data.""" + if len(labels["instances"].segments) == 0 or self.p == 0: + return labels + # Get index of one or three other images + indexes = self.get_indexes() + if isinstance(indexes, int): + indexes = [indexes] + + # Get images information will be used for Mosaic or MixUp + mix_labels = [self.dataset.get_image_and_label(i) for i in indexes] + + if self.pre_transform is not None: + for i, data in enumerate(mix_labels): + mix_labels[i] = self.pre_transform(data) + labels["mix_labels"] = mix_labels + + # Update cls and texts + labels = self._update_label_text(labels) + # Mosaic or MixUp + labels = self._mix_transform(labels) + labels.pop("mix_labels", None) + return labels + + +class FlipCopyPaste: """ Implements the Copy-Paste augmentation as described in the paper https://arxiv.org/abs/2012.07177. This class is responsible for applying the Copy-Paste augmentation on images and their corresponding instances. @@ -849,7 +947,7 @@ class CopyPaste: # for j in random.sample(list(indexes), k=round(self.p * n)): sorted_idx = np.argsort(ioa.max(1)[indexes]) indexes = indexes[sorted_idx] - for j in indexes[:round(self.p * n)]: + for j in indexes[: round(self.p * n)]: cls = np.concatenate((cls, cls[[j]]), axis=0) instances = Instances.concatenate((instances, ins_flip[[j]]), axis=0) cv2.drawContours(im_new, instances.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED) @@ -1097,20 +1195,21 @@ class RandomLoadText: def v8_transforms(dataset, imgsz, hyp, stretch=False): """Convert images to a size suitable for YOLOv8 training.""" - pre_transform = Compose( - [ - Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), - CopyPaste(p=hyp.copy_paste), - RandomPerspective( - degrees=hyp.degrees, - translate=hyp.translate, - scale=hyp.scale, - shear=hyp.shear, - perspective=hyp.perspective, - pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), - ), - ] + mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic) + affine = RandomPerspective( + degrees=hyp.degrees, + translate=hyp.translate, + scale=hyp.scale, + shear=hyp.shear, + perspective=hyp.perspective, + pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)), ) + + pre_transform = Compose([mosaic, affine]) + if hyp.copy_paste_mode == "flip": + pre_transform.insert(1, FlipCopyPaste(p=hyp.copy_paste)) + else: + pre_transform.append(CopyPaste(dataset, pre_transform=Compose([mosaic, affine]), p=hyp.copy_paste)) flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation if dataset.use_keypoints: kpt_shape = dataset.data.get("kpt_shape", None)