add use_cls_weight

exp-b
Laughing-q 6 months ago
parent 9550fe0949
commit 4ecc65ac7b
  1. 17
      ultralytics/data/augment.py

@@ -188,7 +188,7 @@ class Mosaic(BaseMixTransform):
n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3).
"""
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
def __init__(self, dataset, imgsz=640, p=1.0, n=4, use_cls_weight=False):
"""Initializes the object with a dataset, image size, probability, and border."""
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
assert n in {4, 9}, "grid must be equal to 4 or 9."
@@ -197,10 +197,11 @@ class Mosaic(BaseMixTransform):
self.imgsz = imgsz
self.border = (-imgsz // 2, -imgsz // 2) # width, height
self.n = n
self.use_cls_weight = use_cls_weight
def get_indexes(self, buffer=True):
def get_indexes(self, buffer=False):
"""Return a list of random indexes from the dataset."""
if self.dataset.cls_weights is not None:
if self.use_cls_weight and self.dataset.cls_weights is not None:
return np.random.choice(len(self.dataset), self.n - 1, replace=False, p=self.dataset.cls_weights)
elif buffer: # select images from buffer
return random.choices(list(self.dataset.buffer), k=self.n - 1)
@@ -1197,7 +1198,7 @@ class RandomLoadText:
def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic)
mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, use_cls_weight=False)
affine = RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
@@ -1211,7 +1212,13 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
if hyp.copy_paste_mode == "flip":
pre_transform.insert(1, FlipCopyPaste(p=hyp.copy_paste))
else:
pre_transform.append(CopyPaste(dataset, pre_transform=Compose([mosaic, affine]), p=hyp.copy_paste))
pre_transform.append(
CopyPaste(
dataset,
pre_transform=Compose([Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, use_cls_weight=True), affine]),
p=hyp.copy_paste,
)
)
flip_idx = dataset.data.get("flip_idx", []) # for keypoints augmentation
if dataset.use_keypoints:
kpt_shape = dataset.data.get("kpt_shape", None)

Loading…
Cancel
Save