remove cls_weight

clean-exp-bk
Laughing-q 3 months ago
parent 8e4cd40fdb
commit d5a19bf203
  1. 1
      ultralytics/cfg/default.yaml
  2. 23
      ultralytics/data/augment.py
  3. 28
      ultralytics/data/base.py

@ -116,7 +116,6 @@ mosaic: 1.0 # (float) image mosaic (probability)
mixup: 0.0 # (float) image mixup (probability)
copy_paste: 0.0 # (float) segment copy-paste (probability)
copy_paste_mode: "flip" # (str) the method to do copy_paste augmentation (flip, mixup)
cls_weight: False # (bool) Whether to sample images with cls weights when doing mosaic augmentation
auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix)
erasing: 0.4 # (float) probability of random erasing during classification training (0-0.9), 0 means no erasing, must be less than 1.0.
crop_fraction: 1.0 # (float) image crop fraction for classification (0.1-1), 1.0 means no crop, must be greater than 0.

@ -188,7 +188,7 @@ class Mosaic(BaseMixTransform):
n (int, optional): The grid size, either 4 (for 2x2) or 9 (for 3x3).
"""
def __init__(self, dataset, imgsz=640, p=1.0, n=4, use_cls_weight=False):
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
"""Initializes the object with a dataset, image size, probability, and border."""
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
assert n in {4, 9}, "grid must be equal to 4 or 9."
@ -197,13 +197,10 @@ class Mosaic(BaseMixTransform):
self.imgsz = imgsz
self.border = (-imgsz // 2, -imgsz // 2) # width, height
self.n = n
self.use_cls_weight = use_cls_weight
def get_indexes(self, buffer=False):
def get_indexes(self, buffer=True):
"""Return a list of random indexes from the dataset."""
if self.use_cls_weight and self.dataset.cls_weights is not None:
return np.random.choice(len(self.dataset), self.n - 1, replace=False, p=self.dataset.cls_weights)
elif buffer: # select images from buffer
if buffer: # select images from buffer
return random.choices(list(self.dataset.buffer), k=self.n - 1)
else: # select any images
return [random.randint(0, len(self.dataset) - 1) for _ in range(self.n - 1)]
@ -379,10 +376,7 @@ class MixUp(BaseMixTransform):
def get_indexes(self):
"""Get a random index from the dataset."""
if self.dataset.cls_weights is not None:
return np.random.choice(len(self.dataset), 1, replace=False, p=self.dataset.cls_weights)
else:
return random.randint(0, len(self.dataset) - 1)
return random.randint(0, len(self.dataset) - 1)
def _mix_transform(self, labels):
"""Applies MixUp augmentation as per https://arxiv.org/pdf/1710.09412.pdf."""
@ -840,10 +834,7 @@ class CopyPaste(BaseMixTransform):
>>> print(index)
42
"""
if self.dataset.cls_weights is not None:
return np.random.choice(len(self.dataset), 1, replace=False, p=self.dataset.cls_weights)
else:
return random.randint(0, len(self.dataset) - 1)
return random.randint(0, len(self.dataset) - 1)
def _mix_transform(self, labels):
"""Applies CopyPaste augmentation."""
@ -1204,7 +1195,7 @@ class RandomLoadText:
def v8_transforms(dataset, imgsz, hyp, stretch=False):
"""Convert images to a size suitable for YOLOv8 training."""
mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, use_cls_weight=False)
mosaic = Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic)
affine = RandomPerspective(
degrees=hyp.degrees,
translate=hyp.translate,
@ -1221,7 +1212,7 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
pre_transform.append(
CopyPaste(
dataset,
pre_transform=Compose([Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic, use_cls_weight=False), affine]),
pre_transform=Compose([Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic), affine]),
p=hyp.copy_paste,
)
)

@ -93,37 +93,9 @@ class BaseDataset(Dataset):
if (self.cache == "ram" and self.check_cache_ram()) or self.cache == "disk":
self.cache_images()
self.cls_weights = self.calculate_cls_weights() if hyp.cls_weight else None
# Transforms
self.transforms = self.build_transforms(hyp=hyp)
def calculate_cls_weights(self):
    """Return a per-image sampling probability that favours images with rare classes.

    Each class receives the weight ``total_instances / class_instances``; an
    image's weight is the sum of the class weights of all its instances.
    Images without any instance keep a weight of 1. The weights are then
    normalised so they sum to 1.

    Reads ``self.labels`` (a sequence of dicts, each with a ``"cls"`` array)
    and ``self.data["names"]`` (its length is the minimum number of classes).

    Returns:
        (np.ndarray): Array of length ``len(self.labels)`` that sums to 1, or
            an empty array when there are no labels.
    """
    num_images = len(self.labels)
    if num_images == 0:
        # np.concatenate raises on an empty sequence; there is nothing to weight.
        return np.zeros(0, dtype=float)

    # Flatten every image's class ids once and reuse them for counting and indexing.
    per_image_cls = [label["cls"].reshape(-1).astype(int) for label in self.labels]
    counts = np.bincount(np.concatenate(per_image_cls), minlength=len(self.data["names"]))

    # Classes with zero instances are never indexed below, so leave their weight
    # at 0 instead of dividing by zero (which would emit a RuntimeWarning).
    class_weights = np.zeros(len(counts), dtype=float)
    np.divide(counts.sum(), counts, out=class_weights, where=counts > 0)

    im_weights = np.ones(num_images, dtype=float)
    for i, image_cls in enumerate(per_image_cls):
        if image_cls.size:
            im_weights[i] = class_weights[image_cls].sum()
    return im_weights / im_weights.sum()
def get_img_files(self, img_path):
"""Read image files."""
try:

Loading…
Cancel
Save