Resolve `albumentations` UserWarning (#13098)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/13420/head
Laughing 9 months ago committed by GitHub
parent d7bbfa42ef
commit a357fac441
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      docs/mkdocs_github_authors.yaml
  2. 77
      ultralytics/data/augment.py

@ -28,4 +28,4 @@ priytosh.revolution@live.com: priytosh-tripathi
shuizhuyuanluo@126.com: null
stormsson@users.noreply.github.com: stormsson
xinwang614@gmail.com: GreatV
andrei.kochin@intel: andrei-kochin
andrei.kochin@intel.com: andrei-kochin

@ -874,11 +874,56 @@ class Albumentations:
self.p = p
self.transform = None
prefix = colorstr("albumentations: ")
try:
import albumentations as A
check_version(A.__version__, "1.0.3", hard=True) # version requirement
# List of possible spatial transforms
spatial_transforms = {
"Affine",
"BBoxSafeRandomCrop",
"CenterCrop",
"CoarseDropout",
"Crop",
"CropAndPad",
"CropNonEmptyMaskIfExists",
"D4",
"ElasticTransform",
"Flip",
"GridDistortion",
"GridDropout",
"HorizontalFlip",
"Lambda",
"LongestMaxSize",
"MaskDropout",
"MixUp",
"Morphological",
"NoOp",
"OpticalDistortion",
"PadIfNeeded",
"Perspective",
"PiecewiseAffine",
"PixelDropout",
"RandomCrop",
"RandomCropFromBorders",
"RandomGridShuffle",
"RandomResizedCrop",
"RandomRotate90",
"RandomScale",
"RandomSizedBBoxSafeCrop",
"RandomSizedCrop",
"Resize",
"Rotate",
"SafeRotate",
"ShiftScaleRotate",
"SmallestMaxSize",
"Transpose",
"VerticalFlip",
"XYMasking",
} # from https://albumentations.ai/docs/getting_started/transforms_and_targets/#spatial-level-transforms
# Transforms
T = [
A.Blur(p=0.01),
@ -889,8 +934,14 @@ class Albumentations:
A.RandomGamma(p=0.0),
A.ImageCompression(quality_lower=75, p=0.0),
]
self.transform = A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
# Compose transforms
self.contains_spatial = any(transform.__class__.__name__ in spatial_transforms for transform in T)
self.transform = (
A.Compose(T, bbox_params=A.BboxParams(format="yolo", label_fields=["class_labels"]))
if self.contains_spatial
else A.Compose(T)
)
LOGGER.info(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
except ImportError: # package not installed, skip
pass
@ -899,20 +950,26 @@ class Albumentations:
def __call__(self, labels):
"""Generates object detections and returns a dictionary with detection results."""
im = labels["img"]
cls = labels["cls"]
if len(cls):
labels["instances"].convert_bbox("xywh")
labels["instances"].normalize(*im.shape[:2][::-1])
bboxes = labels["instances"].bboxes
# TODO: add supports of segments and keypoints
if self.transform and random.random() < self.p:
if self.transform is None or random.random() > self.p:
return labels
if self.contains_spatial:
cls = labels["cls"]
if len(cls):
im = labels["img"]
labels["instances"].convert_bbox("xywh")
labels["instances"].normalize(*im.shape[:2][::-1])
bboxes = labels["instances"].bboxes
# TODO: add supports of segments and keypoints
new = self.transform(image=im, bboxes=bboxes, class_labels=cls) # transformed
if len(new["class_labels"]) > 0: # skip update if no bbox in new im
labels["img"] = new["image"]
labels["cls"] = np.array(new["class_labels"])
bboxes = np.array(new["bboxes"], dtype=np.float32)
labels["instances"].update(bboxes=bboxes)
labels["instances"].update(bboxes=bboxes)
else:
labels["img"] = self.transform(image=labels["img"])["image"] # transformed
return labels

Loading…
Cancel
Save