Update base.py

exp-b
Laughing-q 6 months ago
parent a1d4a7591f
commit f4114cffc9
  1. 16
      ultralytics/data/base.py

@ -100,18 +100,22 @@ class BaseDataset(Dataset):
def calculate_cls_weights(self):
cls = np.concatenate([l["cls"].reshape(-1) for l in self.labels])
counts = np.bincount(cls.astype(int), minlength=len(self.data["names"]))
weights = np.zeros(len(self.labels))
# weights = np.ones(len(self.labels))
class_weights = counts.sum() / counts
# weights = np.zeros(len(self.labels))
weights = np.ones(len(self.labels))
for i, label in enumerate(self.labels):
cls = label["cls"].reshape(-1).astype(np.int32)
if len(cls) == 0:
continue
weights[i] = np.mean(counts[cls])
# set mean value of weights for background images
weights = np.where(weights == 0, weights.mean(), weights)
weights = weights.max() - weights + 1
weights[i] = np.sum(class_weights[cls])
return weights / weights.sum()
# weights[i] = np.mean(counts[cls])
# set mean value of weights for background images
# weights = np.where(weights == 0, weights.mean(), weights)
# weights = weights.max() - weights + 1
# return weights / weights.sum()
def get_img_files(self, img_path):
"""Read image files."""
try:

Loading…
Cancel
Save