Update base.py

exp-b
Laughing-q 6 months ago
parent f4114cffc9
commit be333c0797
  1. 6
      ultralytics/data/base.py

@ -102,13 +102,13 @@ class BaseDataset(Dataset):
counts = np.bincount(cls.astype(int), minlength=len(self.data["names"])) counts = np.bincount(cls.astype(int), minlength=len(self.data["names"]))
class_weights = counts.sum() / counts class_weights = counts.sum() / counts
# weights = np.zeros(len(self.labels)) # weights = np.zeros(len(self.labels))
weights = np.ones(len(self.labels)) im_weights = np.ones(len(self.labels))
for i, label in enumerate(self.labels): for i, label in enumerate(self.labels):
cls = label["cls"].reshape(-1).astype(np.int32) cls = label["cls"].reshape(-1).astype(np.int32)
if len(cls) == 0: if len(cls) == 0:
continue continue
weights[i] = np.sum(class_weights[cls]) im_weights[i] = np.sum(class_weights[cls])
return weights / weights.sum() return im_weights / im_weights.sum()
# weights[i] = np.mean(counts[cls]) # weights[i] = np.mean(counts[cls])
# set mean value of weights for background images # set mean value of weights for background images

Loading…
Cancel
Save