Update Validator to use `model` argument (#4480)

pull/4482/head
Glenn Jocher 1 year ago committed by GitHub
parent 615ddc9d97
commit b2f279ffdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      ultralytics/cfg/__init__.py
  2. 9
      ultralytics/engine/validator.py
  3. 9
      ultralytics/models/rtdetr/val.py
  4. 2
      ultralytics/models/yolo/classify/val.py
  5. 2
      ultralytics/models/yolo/detect/val.py
  6. 2
      ultralytics/models/yolo/pose/val.py
  7. 2
      ultralytics/models/yolo/segment/val.py

@ -82,7 +82,7 @@ def cfg2dict(cfg):
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object. Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
Args: Args:
cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary. cfg (str | Path | dict | SimpleNamespace): Configuration object to be converted to a dictionary.
Returns: Returns:
cfg (dict): Configuration object in dictionary format. cfg (dict): Configuration object in dictionary format.
@ -110,6 +110,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
# Merge overrides # Merge overrides
if overrides: if overrides:
overrides = cfg2dict(overrides) overrides = cfg2dict(overrides)
overrides.pop('save_dir', None) # special override keys to ignore
check_dict_alignment(cfg, overrides) check_dict_alignment(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides) cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)

@ -29,7 +29,7 @@ from tqdm import tqdm
from ultralytics.cfg import get_cfg from ultralytics.cfg import get_cfg
from ultralytics.data.utils import check_cls_dataset, check_det_dataset from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.files import increment_path from ultralytics.utils.files import increment_path
from ultralytics.utils.ops import Profile from ultralytics.utils.ops import Profile
@ -43,9 +43,9 @@ class BaseValidator:
A base class for creating validators. A base class for creating validators.
Attributes: Attributes:
args (SimpleNamespace): Configuration for the validator.
dataloader (DataLoader): Dataloader to use for validation. dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation. pbar (tqdm): Progress bar to update during validation.
args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate. model (nn.Module): Model to validate.
data (dict): Data dictionary. data (dict): Data dictionary.
device (torch.device): Device to use for validation. device (torch.device): Device to use for validation.
@ -76,9 +76,9 @@ class BaseValidator:
args (SimpleNamespace): Configuration for the validator. args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions. _callbacks (dict): Dictionary to store various callback functions.
""" """
self.args = get_cfg(overrides=args)
self.dataloader = dataloader self.dataloader = dataloader
self.pbar = pbar self.pbar = pbar
self.args = args or get_cfg(DEFAULT_CFG)
self.model = None self.model = None
self.data = None self.data = None
self.device = None self.device = None
@ -126,8 +126,7 @@ class BaseValidator:
else: else:
callbacks.add_integration_callbacks(self) callbacks.add_integration_callbacks(self)
self.run_callbacks('on_val_start') self.run_callbacks('on_val_start')
assert model is not None, 'Either trainer or model is needed for validation' model = AutoBackend(model or self.args.model,
model = AutoBackend(model,
device=select_device(self.args.device, self.args.batch), device=select_device(self.args.device, self.args.batch),
dnn=self.args.dnn, dnn=self.args.dnn,
data=self.args.data, data=self.args.data,

@ -14,7 +14,7 @@ from ultralytics.utils import colorstr, ops
__all__ = 'RTDETRValidator', # tuple or list __all__ = 'RTDETRValidator', # tuple or list
# TODO: Temporarily, RT-DETR does not need padding. # TODO: Temporarily RT-DETR does not need padding.
class RTDETRDataset(YOLODataset): class RTDETRDataset(YOLODataset):
def __init__(self, *args, data=None, **kwargs): def __init__(self, *args, data=None, **kwargs):
@ -47,7 +47,7 @@ class RTDETRDataset(YOLODataset):
return self.ims[i], self.im_hw0[i], self.im_hw[i] return self.ims[i], self.im_hw0[i], self.im_hw[i]
def build_transforms(self, hyp=None): def build_transforms(self, hyp=None):
"""Temporarily, only for evaluation.""" """Temporary, only for evaluation."""
if self.augment: if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0 hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0 hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
@ -76,12 +76,13 @@ class RTDETRValidator(DetectionValidator):
args = dict(model='rtdetr-l.pt', data='coco8.yaml') args = dict(model='rtdetr-l.pt', data='coco8.yaml')
validator = RTDETRValidator(args=args) validator = RTDETRValidator(args=args)
validator(model=args['model']) validator()
``` ```
""" """
def build_dataset(self, img_path, mode='val', batch=None): def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset """
Build an RTDETR Dataset.
Args: Args:
img_path (str): Path to the folder containing images. img_path (str): Path to the folder containing images.

@ -22,7 +22,7 @@ class ClassificationValidator(BaseValidator):
args = dict(model='yolov8n-cls.pt', data='imagenet10') args = dict(model='yolov8n-cls.pt', data='imagenet10')
validator = ClassificationValidator(args=args) validator = ClassificationValidator(args=args)
validator(model=args['model']) validator()
``` ```
""" """

@ -25,7 +25,7 @@ class DetectionValidator(BaseValidator):
args = dict(model='yolov8n.pt', data='coco8.yaml') args = dict(model='yolov8n.pt', data='coco8.yaml')
validator = DetectionValidator(args=args) validator = DetectionValidator(args=args)
validator(model=args['model']) validator()
``` ```
""" """

@ -22,7 +22,7 @@ class PoseValidator(DetectionValidator):
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml') args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
validator = PoseValidator(args=args) validator = PoseValidator(args=args)
validator(model=args['model']) validator()
``` ```
""" """

@ -24,7 +24,7 @@ class SegmentationValidator(DetectionValidator):
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml') args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
validator = SegmentationValidator(args=args) validator = SegmentationValidator(args=args)
validator(model=args['model']) validator()
``` ```
""" """

Loading…
Cancel
Save