From be333c0797e6c5d6ee0281ac3c023b473bbb110d Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Wed, 4 Sep 2024 21:37:22 +0800 Subject: [PATCH] Update base.py --- ultralytics/data/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ultralytics/data/base.py b/ultralytics/data/base.py index 4b488ade9a..1e46820bbe 100644 --- a/ultralytics/data/base.py +++ b/ultralytics/data/base.py @@ -102,13 +102,13 @@ class BaseDataset(Dataset): counts = np.bincount(cls.astype(int), minlength=len(self.data["names"])) class_weights = counts.sum() / counts # weights = np.zeros(len(self.labels)) - weights = np.ones(len(self.labels)) + im_weights = np.ones(len(self.labels)) for i, label in enumerate(self.labels): cls = label["cls"].reshape(-1).astype(np.int32) if len(cls) == 0: continue - weights[i] = np.sum(class_weights[cls]) - return weights / weights.sum() + im_weights[i] = np.sum(class_weights[cls]) + return im_weights / im_weights.sum() # weights[i] = np.mean(counts[cls]) # set mean value of weights for background images