diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py index e0f34be946..4b488ade9a 100644 --- a/ultralytics/data/base.py +++ b/ultralytics/data/base.py @@ -100,18 +100,22 @@ class BaseDataset(Dataset): def calculate_cls_weights(self): cls = np.concatenate([l["cls"].reshape(-1) for l in self.labels]) counts = np.bincount(cls.astype(int), minlength=len(self.data["names"])) - weights = np.zeros(len(self.labels)) - # weights = np.ones(len(self.labels)) + class_weights = counts.sum() / counts + # weights = np.zeros(len(self.labels)) + weights = np.ones(len(self.labels)) for i, label in enumerate(self.labels): cls = label["cls"].reshape(-1).astype(np.int32) if len(cls) == 0: continue - weights[i] = np.mean(counts[cls]) - # set mean value of weights for background images - weights = np.where(weights == 0, weights.mean(), weights) - weights = weights.max() - weights + 1 + weights[i] = np.sum(class_weights[cls]) return weights / weights.sum() + # weights[i] = np.mean(counts[cls]) + # set mean value of weights for background images + # weights = np.where(weights == 0, weights.mean(), weights) + # weights = weights.max() - weights + 1 + # return weights / weights.sum() + def get_img_files(self, img_path): """Read image files.""" try: