diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py index 4b488ade9a..1e46820bbe 100644 --- a/ultralytics/data/base.py +++ b/ultralytics/data/base.py @@ -102,13 +102,13 @@ class BaseDataset(Dataset): counts = np.bincount(cls.astype(int), minlength=len(self.data["names"])) class_weights = counts.sum() / counts # weights = np.zeros(len(self.labels)) - weights = np.ones(len(self.labels)) + im_weights = np.ones(len(self.labels)) for i, label in enumerate(self.labels): cls = label["cls"].reshape(-1).astype(np.int32) if len(cls) == 0: continue - weights[i] = np.sum(class_weights[cls]) - return weights / weights.sum() + im_weights[i] = np.sum(class_weights[cls]) + return im_weights / im_weights.sum() # weights[i] = np.mean(counts[cls]) # set mean value of weights for background images