def calculate_cls_weights(self):
    """Compute normalized per-image sampling weights from class frequencies.

    Each class gets an inverse-frequency weight (total instance count divided by
    that class's instance count). An image's raw weight is the sum of the class
    weights of all its instances; an image without any instance keeps a raw
    weight of 1. The raw weights are then normalized so they sum to 1.

    Reads:
        self.labels: Sequence of dicts, each holding a "cls" array of class indices.
        self.data["names"]: Collection whose length gives the number of classes.

    Returns:
        (np.ndarray): 1-D float array with one weight per entry of ``self.labels``,
            summing to 1. Empty when ``self.labels`` is empty.
    """
    num_images = len(self.labels)
    if num_images == 0:
        # np.concatenate raises on an empty sequence; there is nothing to weight.
        return np.ones(0)

    # Flatten and cast every label's class ids once, then reuse them below.
    per_image_cls = [label["cls"].reshape(-1).astype(int) for label in self.labels]
    counts = np.bincount(np.concatenate(per_image_cls), minlength=len(self.data["names"]))

    # Classes without instances get weight 0 instead of triggering a division by
    # zero; no image references such a class, so the result is unaffected.
    class_weights = np.zeros(len(counts), dtype=float)
    present = counts > 0
    class_weights[present] = counts.sum() / counts[present]

    im_weights = np.ones(num_images)
    for i, image_cls in enumerate(per_image_cls):
        if image_cls.size:
            im_weights[i] = class_weights[image_cls].sum()
    return im_weights / im_weights.sum()