Update base.py

exp-b
Laughing-q 6 months ago
parent 727b14fb63
commit 1d80cb2417
  1. 9
      ultralytics/data/base.py

@ -100,27 +100,20 @@ class BaseDataset(Dataset):
def calculate_cls_weights(self):
cls = np.concatenate([l["cls"].reshape(-1) for l in self.labels])
counts = np.bincount(cls.astype(int), minlength=len(self.data["names"]))
rare_cls = np.nonzero(counts <= 100)[0]
class_weights = counts.sum() / counts
# weights = np.zeros(len(self.labels))
im_weights = np.ones(len(self.labels))
rare_idx = []
for i, label in enumerate(self.labels):
cls = label["cls"].reshape(-1).astype(np.int32)
for rc in rare_cls:
if rc in cls:
print(label["im_file"])
rare_idx.append(i)
if len(cls) == 0:
continue
im_weights[i] = np.sum(class_weights[cls])
#
# import matplotlib.pyplot as plt
# plt.switch_backend("Agg")
# _, ax = plt.subplots(2, 1, figsize=(21, 6), tight_layout=True)
# ax = ax.ravel()
# ax[0].plot(im_weights / im_weights.sum())
# ax[1].plot((im_weights / im_weights.sum())[rare_cls])
# plt.savefig("cls.png")
# exit()
return im_weights / im_weights.sum()

Loading…
Cancel
Save