|
|
|
@ -32,6 +32,7 @@ class DetectionValidator(BaseValidator): |
|
|
|
|
"""Initialize detection model with necessary variables and settings.""" |
|
|
|
|
super().__init__(dataloader, save_dir, pbar, args, _callbacks) |
|
|
|
|
self.nt_per_class = None |
|
|
|
|
self.nt_per_image = None |
|
|
|
|
self.is_coco = False |
|
|
|
|
self.is_lvis = False |
|
|
|
|
self.class_map = None |
|
|
|
@ -77,7 +78,7 @@ class DetectionValidator(BaseValidator): |
|
|
|
|
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf) |
|
|
|
|
self.seen = 0 |
|
|
|
|
self.jdict = [] |
|
|
|
|
self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[]) |
|
|
|
|
self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) |
|
|
|
|
|
|
|
|
|
def get_desc(self): |
|
|
|
|
"""Return a formatted string summarizing class metrics of YOLO model.""" |
|
|
|
@ -130,6 +131,7 @@ class DetectionValidator(BaseValidator): |
|
|
|
|
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox") |
|
|
|
|
nl = len(cls) |
|
|
|
|
stat["target_cls"] = cls |
|
|
|
|
stat["target_img"] = cls.unique() |
|
|
|
|
if npr == 0: |
|
|
|
|
if nl: |
|
|
|
|
for k in self.stats.keys(): |
|
|
|
@ -168,11 +170,11 @@ class DetectionValidator(BaseValidator): |
|
|
|
|
def get_stats(self): |
|
|
|
|
"""Returns metrics statistics and results dictionary.""" |
|
|
|
|
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy |
|
|
|
|
self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc) |
|
|
|
|
self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc) |
|
|
|
|
stats.pop("target_img", None) |
|
|
|
|
if len(stats) and stats["tp"].any(): |
|
|
|
|
self.metrics.process(**stats) |
|
|
|
|
self.nt_per_class = np.bincount( |
|
|
|
|
stats["target_cls"].astype(int), minlength=self.nc |
|
|
|
|
) # number of targets per class |
|
|
|
|
return self.metrics.results_dict |
|
|
|
|
|
|
|
|
|
def print_results(self): |
|
|
|
@ -185,7 +187,9 @@ class DetectionValidator(BaseValidator): |
|
|
|
|
# Print results per class |
|
|
|
|
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats): |
|
|
|
|
for i, c in enumerate(self.metrics.ap_class_index): |
|
|
|
|
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i))) |
|
|
|
|
LOGGER.info( |
|
|
|
|
pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i)) |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if self.args.plots: |
|
|
|
|
for normalize in True, False: |
|
|
|
|