|
|
|
@ -1,4 +1,3 @@ |
|
|
|
|
import collections |
|
|
|
|
import math |
|
|
|
|
import random |
|
|
|
|
from copy import deepcopy |
|
|
|
@ -65,7 +64,8 @@ class Compose: |
|
|
|
|
class BaseMixTransform: |
|
|
|
|
"""This implementation is from mmyolo""" |
|
|
|
|
|
|
|
|
|
def __init__(self, pre_transform=None, p=0.0) -> None: |
|
|
|
|
def __init__(self, dataset, pre_transform=None, p=0.0) -> None: |
|
|
|
|
self.dataset = dataset |
|
|
|
|
self.pre_transform = pre_transform |
|
|
|
|
self.p = p |
|
|
|
|
|
|
|
|
@ -73,41 +73,28 @@ class BaseMixTransform: |
|
|
|
|
if random.uniform(0, 1) > self.p: |
|
|
|
|
return labels |
|
|
|
|
|
|
|
|
|
assert "dataset" in labels |
|
|
|
|
dataset = labels.pop("dataset") |
|
|
|
|
|
|
|
|
|
# get index of one or three other images |
|
|
|
|
indexes = self.get_indexes(dataset) |
|
|
|
|
if not isinstance(indexes, collections.abc.Sequence): |
|
|
|
|
indexes = self.get_indexes() |
|
|
|
|
if isinstance(indexes, int): |
|
|
|
|
indexes = [indexes] |
|
|
|
|
|
|
|
|
|
# get images information will be used for Mosaic or MixUp |
|
|
|
|
mix_labels = [dataset.get_label_info(index) for index in indexes] |
|
|
|
|
mix_labels = [self.dataset.get_label_info(i) for i in indexes] |
|
|
|
|
|
|
|
|
|
if self.pre_transform is not None: |
|
|
|
|
for i, data in enumerate(mix_labels): |
|
|
|
|
# pre_transform may also require dataset |
|
|
|
|
data.update({"dataset": dataset}) |
|
|
|
|
# before Mosaic or MixUp need to go through |
|
|
|
|
# the necessary pre_transform |
|
|
|
|
_labels = self.pre_transform(data) |
|
|
|
|
_labels.pop("dataset") |
|
|
|
|
mix_labels[i] = _labels |
|
|
|
|
mix_labels[i] = self.pre_transform(data) |
|
|
|
|
labels["mix_labels"] = mix_labels |
|
|
|
|
|
|
|
|
|
# Mosaic or MixUp |
|
|
|
|
labels = self._mix_transform(labels) |
|
|
|
|
|
|
|
|
|
if "mix_labels" in labels: |
|
|
|
|
labels.pop("mix_labels") |
|
|
|
|
labels["dataset"] = dataset |
|
|
|
|
|
|
|
|
|
labels.pop("mix_labels", None) |
|
|
|
|
return labels |
|
|
|
|
|
|
|
|
|
def _mix_transform(self, labels): |
|
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
def get_indexes(self, dataset): |
|
|
|
|
def get_indexes(self): |
|
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -119,14 +106,15 @@ class Mosaic(BaseMixTransform): |
|
|
|
|
Default to (640, 640). |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, imgsz=640, p=1.0, border=(0, 0)): |
|
|
|
|
def __init__(self, dataset, imgsz=640, p=1.0, border=(0, 0)): |
|
|
|
|
assert 0 <= p <= 1.0, "The probability should be in range [0, 1]. " f"got {p}." |
|
|
|
|
super().__init__(pre_transform=None, p=p) |
|
|
|
|
super().__init__(dataset=dataset, p=p) |
|
|
|
|
self.dataset = dataset |
|
|
|
|
self.imgsz = imgsz |
|
|
|
|
self.border = border |
|
|
|
|
|
|
|
|
|
def get_indexes(self, dataset): |
|
|
|
|
return [random.randint(0, len(dataset) - 1) for _ in range(3)] |
|
|
|
|
def get_indexes(self): |
|
|
|
|
return [random.randint(0, len(self.dataset) - 1) for _ in range(3)] |
|
|
|
|
|
|
|
|
|
def _mix_transform(self, labels): |
|
|
|
|
mosaic_labels = [] |
|
|
|
@ -193,25 +181,19 @@ class Mosaic(BaseMixTransform): |
|
|
|
|
|
|
|
|
|
class MixUp(BaseMixTransform): |
|
|
|
|
|
|
|
|
|
def __init__(self, pre_transform=None, p=0.0) -> None: |
|
|
|
|
super().__init__(pre_transform=pre_transform, p=p) |
|
|
|
|
def __init__(self, dataset, pre_transform=None, p=0.0) -> None: |
|
|
|
|
super().__init__(dataset=dataset, pre_transform=pre_transform, p=p) |
|
|
|
|
|
|
|
|
|
def get_indexes(self, dataset): |
|
|
|
|
return random.randint(0, len(dataset) - 1) |
|
|
|
|
def get_indexes(self): |
|
|
|
|
return random.randint(0, len(self.dataset) - 1) |
|
|
|
|
|
|
|
|
|
def _mix_transform(self, labels): |
|
|
|
|
im = labels["img"] |
|
|
|
|
labels2 = labels["mix_labels"][0] |
|
|
|
|
im2 = labels2["img"] |
|
|
|
|
cls2 = labels2["cls"] |
|
|
|
|
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf |
|
|
|
|
r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 |
|
|
|
|
im = (im * r + im2 * (1 - r)).astype(np.uint8) |
|
|
|
|
cat_instances = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0) |
|
|
|
|
cls = labels["cls"] |
|
|
|
|
labels["img"] = im |
|
|
|
|
labels["instances"] = cat_instances |
|
|
|
|
labels["cls"] = np.concatenate([cls, cls2], 0) |
|
|
|
|
labels2 = labels["mix_labels"][0] |
|
|
|
|
labels["img"] = (labels["img"] * r + labels2["img"] * (1 - r)).astype(np.uint8) |
|
|
|
|
labels["instances"] = Instances.concatenate([labels["instances"], labels2["instances"]], axis=0) |
|
|
|
|
labels["cls"] = np.concatenate([labels["cls"], labels2["cls"]], 0) |
|
|
|
|
return labels |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -412,7 +394,6 @@ class RandomHSV: |
|
|
|
|
|
|
|
|
|
im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) |
|
|
|
|
cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed |
|
|
|
|
labels["img"] = img |
|
|
|
|
return labels |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -606,7 +587,6 @@ class Format: |
|
|
|
|
self.batch_idx = batch_idx # keep the batch indexes |
|
|
|
|
|
|
|
|
|
def __call__(self, labels): |
|
|
|
|
labels.pop("dataset", None) |
|
|
|
|
img = labels["img"] |
|
|
|
|
h, w = img.shape[:2] |
|
|
|
|
cls = labels.pop("cls") |
|
|
|
@ -656,9 +636,9 @@ class Format: |
|
|
|
|
return masks, instances, cls |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def mosaic_transforms(imgsz, hyp): |
|
|
|
|
def mosaic_transforms(dataset, imgsz, hyp): |
|
|
|
|
pre_transform = Compose([ |
|
|
|
|
Mosaic(imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]), |
|
|
|
|
Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, border=[-imgsz // 2, -imgsz // 2]), |
|
|
|
|
CopyPaste(p=hyp.copy_paste), |
|
|
|
|
RandomPerspective( |
|
|
|
|
degrees=hyp.degrees, |
|
|
|
@ -670,7 +650,7 @@ def mosaic_transforms(imgsz, hyp): |
|
|
|
|
),]) |
|
|
|
|
return Compose([ |
|
|
|
|
pre_transform, |
|
|
|
|
MixUp(pre_transform=pre_transform, p=hyp.mixup), |
|
|
|
|
MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), |
|
|
|
|
Albumentations(p=1.0), |
|
|
|
|
RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), |
|
|
|
|
RandomFlip(direction="vertical", p=hyp.flipud), |
|
|
|
|