Display Val images per class (#12645)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/13219/head
Adamcode 8 months ago committed by GitHub
parent b95b583237
commit 7cd871dbd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 1
      .github/workflows/ci.yaml
  2. 14
      ultralytics/models/yolo/detect/val.py
  3. 3
      ultralytics/models/yolo/pose/val.py
  4. 3
      ultralytics/models/yolo/segment/val.py

@ -214,7 +214,6 @@ jobs:
if [[ "${{ github.event_name }}" =~ ^(schedule|workflow_dispatch)$ ]]; then
slow="pycocotools mlflow ray[tune]"
fi
slow="pycocotools mlflow ray[tune]"
pip install -e ".[export]" $torch $slow pytest-cov --extra-index-url https://download.pytorch.org/whl/cpu
- name: Check environment
run: |

@ -32,6 +32,7 @@ class DetectionValidator(BaseValidator):
"""Initialize detection model with necessary variables and settings."""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.nt_per_class = None
self.nt_per_image = None
self.is_coco = False
self.is_lvis = False
self.class_map = None
@ -77,7 +78,7 @@ class DetectionValidator(BaseValidator):
self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
self.seen = 0
self.jdict = []
self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[])
self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
def get_desc(self):
"""Return a formatted string summarizing class metrics of YOLO model."""
@ -130,6 +131,7 @@ class DetectionValidator(BaseValidator):
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat["target_cls"] = cls
stat["target_img"] = cls.unique()
if npr == 0:
if nl:
for k in self.stats.keys():
@ -168,11 +170,11 @@ class DetectionValidator(BaseValidator):
def get_stats(self):
"""Returns metrics statistics and results dictionary."""
stats = {k: torch.cat(v, 0).cpu().numpy() for k, v in self.stats.items()} # to numpy
self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=self.nc)
self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=self.nc)
stats.pop("target_img", None)
if len(stats) and stats["tp"].any():
self.metrics.process(**stats)
self.nt_per_class = np.bincount(
stats["target_cls"].astype(int), minlength=self.nc
) # number of targets per class
return self.metrics.results_dict
def print_results(self):
@ -185,7 +187,9 @@ class DetectionValidator(BaseValidator):
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.stats):
for i, c in enumerate(self.metrics.ap_class_index):
LOGGER.info(pf % (self.names[c], self.seen, self.nt_per_class[c], *self.metrics.class_result(i)))
LOGGER.info(
pf % (self.names[c], self.nt_per_image[c], self.nt_per_class[c], *self.metrics.class_result(i))
)
if self.args.plots:
for normalize in True, False:

@ -81,7 +81,7 @@ class PoseValidator(DetectionValidator):
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[])
self.stats = dict(tp_p=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
def _prepare_batch(self, si, batch):
"""Prepares a batch for processing by converting keypoints to float and moving to device."""
@ -118,6 +118,7 @@ class PoseValidator(DetectionValidator):
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat["target_cls"] = cls
stat["target_img"] = cls.unique()
if npr == 0:
if nl:
for k in self.stats.keys():

@ -51,7 +51,7 @@ class SegmentationValidator(DetectionValidator):
self.process = ops.process_mask_upsample # more accurate
else:
self.process = ops.process_mask # faster
self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[])
self.stats = dict(tp_m=[], tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
def get_desc(self):
"""Return a formatted description of evaluation metrics."""
@ -112,6 +112,7 @@ class SegmentationValidator(DetectionValidator):
cls, bbox = pbatch.pop("cls"), pbatch.pop("bbox")
nl = len(cls)
stat["target_cls"] = cls
stat["target_img"] = cls.unique()
if npr == 0:
if nl:
for k in self.stats.keys():

Loading…
Cancel
Save