|
|
@ -100,27 +100,20 @@ class BaseDataset(Dataset): |
|
|
|
def calculate_cls_weights(self): |
|
|
|
def calculate_cls_weights(self): |
|
|
|
cls = np.concatenate([l["cls"].reshape(-1) for l in self.labels]) |
|
|
|
cls = np.concatenate([l["cls"].reshape(-1) for l in self.labels]) |
|
|
|
counts = np.bincount(cls.astype(int), minlength=len(self.data["names"])) |
|
|
|
counts = np.bincount(cls.astype(int), minlength=len(self.data["names"])) |
|
|
|
rare_cls = np.nonzero(counts <= 100)[0] |
|
|
|
|
|
|
|
class_weights = counts.sum() / counts |
|
|
|
class_weights = counts.sum() / counts |
|
|
|
# weights = np.zeros(len(self.labels)) |
|
|
|
# weights = np.zeros(len(self.labels)) |
|
|
|
im_weights = np.ones(len(self.labels)) |
|
|
|
im_weights = np.ones(len(self.labels)) |
|
|
|
rare_idx = [] |
|
|
|
|
|
|
|
for i, label in enumerate(self.labels): |
|
|
|
for i, label in enumerate(self.labels): |
|
|
|
cls = label["cls"].reshape(-1).astype(np.int32) |
|
|
|
cls = label["cls"].reshape(-1).astype(np.int32) |
|
|
|
for rc in rare_cls: |
|
|
|
|
|
|
|
if rc in cls: |
|
|
|
|
|
|
|
print(label["im_file"]) |
|
|
|
|
|
|
|
rare_idx.append(i) |
|
|
|
|
|
|
|
if len(cls) == 0: |
|
|
|
if len(cls) == 0: |
|
|
|
continue |
|
|
|
continue |
|
|
|
im_weights[i] = np.sum(class_weights[cls]) |
|
|
|
im_weights[i] = np.sum(class_weights[cls]) |
|
|
|
# |
|
|
|
|
|
|
|
# import matplotlib.pyplot as plt |
|
|
|
# import matplotlib.pyplot as plt |
|
|
|
# plt.switch_backend("Agg") |
|
|
|
# plt.switch_backend("Agg") |
|
|
|
# _, ax = plt.subplots(2, 1, figsize=(21, 6), tight_layout=True) |
|
|
|
# _, ax = plt.subplots(2, 1, figsize=(21, 6), tight_layout=True) |
|
|
|
# ax = ax.ravel() |
|
|
|
# ax = ax.ravel() |
|
|
|
# ax[0].plot(im_weights / im_weights.sum()) |
|
|
|
# ax[0].plot(im_weights / im_weights.sum()) |
|
|
|
# ax[1].plot((im_weights / im_weights.sum())[rare_cls]) |
|
|
|
|
|
|
|
# plt.savefig("cls.png") |
|
|
|
# plt.savefig("cls.png") |
|
|
|
# exit() |
|
|
|
# exit() |
|
|
|
return im_weights / im_weights.sum() |
|
|
|
return im_weights / im_weights.sum() |
|
|
|