Update base.py

exp-b
Laughing-q 6 months ago
parent f4114cffc9
commit be333c0797
  1. 6
      ultralytics/data/base.py

@ -102,13 +102,13 @@ class BaseDataset(Dataset):
counts = np.bincount(cls.astype(int), minlength=len(self.data["names"]))
class_weights = counts.sum() / counts
# weights = np.zeros(len(self.labels))
weights = np.ones(len(self.labels))
im_weights = np.ones(len(self.labels))
for i, label in enumerate(self.labels):
cls = label["cls"].reshape(-1).astype(np.int32)
if len(cls) == 0:
continue
weights[i] = np.sum(class_weights[cls])
return weights / weights.sum()
im_weights[i] = np.sum(class_weights[cls])
return im_weights / im_weights.sum()
# weights[i] = np.mean(counts[cls])
# set mean value of weights for background images

Loading…
Cancel
Save