From eaa4dd975761837efebd1307c3f5e2b8671b50cb Mon Sep 17 00:00:00 2001 From: Francesco Mattioli Date: Tue, 29 Oct 2024 13:58:12 +0100 Subject: [PATCH] new param albumentations_p --- ultralytics/cfg/default.yaml | 1 + ultralytics/data/augment.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml index 7922f63592..182f661232 100644 --- a/ultralytics/cfg/default.yaml +++ b/ultralytics/cfg/default.yaml @@ -119,6 +119,7 @@ copy_paste_mode: "flip" # (str) the method to do copy_paste augmentation (flip, auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix) erasing: 0.4 # (float) probability of random erasing during classification training (0-0.9), 0 means no erasing, must be less than 1.0. crop_fraction: 1.0 # (float) image crop fraction for classification (0.1-1), 1.0 means no crop, must be greater than 0. +albumentations_p: 0.1 # (float) probability of albumentations augmentation during training (0.0-1.0), 0 means no albumentations, must be less than 1.0. # Custom config.yaml --------------------------------------------------------------------------------------------------- cfg: # (str, optional) for overriding defaults.yaml diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index 49bdc92235..088184bcfd 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -1841,13 +1841,13 @@ class Albumentations: # Transforms T = [ - A.Blur(p=0.01), - A.MedianBlur(p=0.01), - A.ToGray(p=0.01), - A.CLAHE(p=0.01), - A.RandomBrightnessContrast(p=0.0), - A.RandomGamma(p=0.0), - A.ImageCompression(quality_lower=75, p=0.0), + A.Blur(p=self.p), + A.MedianBlur(p=self.p), + A.ToGray(p=self.p), + A.CLAHE(p=self.p), + A.RandomBrightnessContrast(p=self.p), + A.RandomGamma(p=self.p), + A.ImageCompression(quality_lower=75, p=self.p), ] # Compose transforms @@ -2328,7 +2328,7 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False): [ pre_transform, MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup), - Albumentations(p=1.0), + Albumentations(p=hyp.albumentations_p), RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v), RandomFlip(direction="vertical", p=hyp.flipud), RandomFlip(direction="horizontal", p=hyp.fliplr, flip_idx=flip_idx),