Update augment.py

exp-b
Laughing-q 3 months ago
parent 9129d8a397
commit 659fdb6422
  1. 117
      ultralytics/data/augment.py

@ -794,7 +794,106 @@ class LetterBox:
return labels
class CopyPaste:
class CopyPaste(BaseMixTransform):
"""
Implements Copy-Paste augmentation as described in https://arxiv.org/abs/2012.07177.
This class applies Copy-Paste augmentation on images and their corresponding instances.
Attributes:
dataset: The dataset on which the copypaste augmentation is applied.
pre_transform: The pre-transforms for the mixed labels.
p (float): Probability of applying the Copy-Paste augmentation. Must be between 0 and 1.
Methods:
__call__: Applies Copy-Paste augmentation to given image and instances.
Examples:
>>> copypaste = CopyPaste(dataset, p=0.5)
>>> augmented_labels = copypaste(labels)
>>> augmented_image = augmented_labels['img']
"""
def __init__(self, dataset, pre_transform=None, p=0.5) -> None:
"""Initializes CopyPaste object with dataset, pre_transform, and probability of applying MixUp."""
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p)
def get_indexes(self):
"""
Get a random index from the dataset.
This method returns a single random index from the dataset, which is used to select an image for MixUp
augmentation.
Returns:
(int): A random integer index within the range of the dataset length.
Examples:
>>> copypaste = CopyPaste(dataset)
>>> index = copypaste.get_indexes()
>>> print(index)
42
"""
return random.randint(0, len(self.dataset) - 1)
def _mix_transform(self, labels):
"""Applies CopyPaste augmentation."""
labels2 = labels["mix_labels"][0]
im = labels["img"]
cls = labels["cls"]
h, w = im.shape[:2]
instances = labels.pop("instances")
instances.convert_bbox(format="xyxy")
instances.denormalize(w, h)
im_new = np.zeros(im.shape, np.uint8)
instances2 = labels2.pop("instances")
ioa = bbox_ioa(instances2.bboxes, instances.bboxes) # intersection over area, (N, M)
indexes = np.nonzero((ioa < 0.30).all(1))[0] # (N, )
n = len(indexes)
# for j in random.sample(list(indexes), k=round(self.p * n)):
sorted_idx = np.argsort(ioa.max(1)[indexes])
indexes = indexes[sorted_idx]
for j in indexes[: round(self.p * n)]:
cls = np.concatenate((cls, labels2["cls"][[j]]), axis=0)
instances = Instances.concatenate((instances, instances2[[j]]), axis=0)
cv2.drawContours(im_new, instances2.segments[[j]].astype(np.int32), -1, (1, 1, 1), cv2.FILLED)
result = labels2["img"] # augment segments
i = im_new.astype(bool)
im[i] = result[i]
labels["img"] = im
labels["cls"] = cls
labels["instances"] = instances
return labels
def __call__(self, labels):
"""Applies pre-processing transforms and copy_paste transforms to labels data."""
if len(labels["instances"].segments) == 0 or self.p == 0:
return labels
# Get index of one or three other images
indexes = self.get_indexes()
if isinstance(indexes, int):
indexes = [indexes]
# Get images information will be used for Mosaic or MixUp
mix_labels = [self.dataset.get_image_and_label(i) for i in indexes]
if self.pre_transform is not None:
for i, data in enumerate(mix_labels):
mix_labels[i] = self.pre_transform(data)
labels["mix_labels"] = mix_labels
# Update cls and texts
labels = self._update_label_text(labels)
# Mosaic or MixUp
labels = self._mix_transform(labels)
labels.pop("mix_labels", None)
return labels
class OldCopyPaste:
"""
Implements the Copy-Paste augmentation as described in the paper https://arxiv.org/abs/2012.07177. This class is
responsible for applying the Copy-Paste augmentation on images and their corresponding instances.
@ -1096,18 +1195,22 @@ class RandomLoadText:
def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
pre_transform = Compose(
[
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
CopyPaste(p=hyp.copy_paste),
RandomPerspective(
mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic)
affine = RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
scale=hyp.scale,
shear=hyp.shear,
perspective=hyp.perspective,
pre_transform=None if stretch else LetterBox(new_shape=(imgsz, imgsz)),
),
)
pre_transform = Compose(
[
mosaic,
# CopyPaste(dataset, pre_transform=mosaic, p=hyp.copy_paste),
# OldCopyPaste(p=hyp.copy_paste),
affine,
CopyPaste(dataset, pre_transform=Compose([mosaic, affine]), p=hyp.copy_paste),
]
)
flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation

Loading…
Cancel
Save