diff --git a/docs/reference/engine/model.md b/docs/reference/engine/model.md index 6ce7bc3a3..f6e468a33 100644 --- a/docs/reference/engine/model.md +++ b/docs/reference/engine/model.md @@ -3,7 +3,7 @@ description: Explore the detailed guide on using the Ultralytics YOLO Engine Mod keywords: Ultralytics, YOLO, engine model, documentation, guide, implementation, training, evaluation --- -## YOLO +## Model --- -### ::: ultralytics.engine.model.YOLO +### ::: ultralytics.engine.model.Model

diff --git a/docs/reference/models/yolo/model.md b/docs/reference/models/yolo/model.md new file mode 100644 index 000000000..5efdec8a8 --- /dev/null +++ b/docs/reference/models/yolo/model.md @@ -0,0 +1,9 @@ +--- +description: Discover the Ultralytics YOLO model class. Learn advanced techniques, tips, and tricks for training. +keywords: Ultralytics YOLO, YOLO, YOLO model, Model Training, Machine Learning, Deep Learning, Computer Vision +--- + +## YOLO +--- +### ::: ultralytics.models.yolo.model.YOLO +

diff --git a/mkdocs.yml b/mkdocs.yml index 75eaf64e6..7faccefc3 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -317,6 +317,7 @@ nav: - predict: reference/models/yolo/detect/predict.md - train: reference/models/yolo/detect/train.md - val: reference/models/yolo/detect/val.md + - model: reference/models/yolo/model.md - pose: - predict: reference/models/yolo/pose/predict.md - train: reference/models/yolo/pose/train.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index a0aaeca32..d1648f707 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,10 +1,9 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = '8.0.142' +__version__ = '8.0.143' -from ultralytics.engine.model import YOLO from ultralytics.hub import start -from ultralytics.models import RTDETR, SAM +from ultralytics.models import RTDETR, SAM, YOLO from ultralytics.models.fastsam import FastSAM from ultralytics.models.nas import NAS from ultralytics.utils import SETTINGS as settings diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py index 9ca82f92f..29b27a54a 100644 --- a/ultralytics/engine/model.py +++ b/ultralytics/engine/model.py @@ -1,36 +1,23 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license +import inspect import sys from pathlib import Path from typing import Union from ultralytics.cfg import get_cfg from ultralytics.engine.exporter import Exporter -from ultralytics.models import yolo # noqa -from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel, - attempt_load_one_weight, guess_model_task, nn, yaml_model_load) +from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks, is_git_dir, yaml_load) from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml from ultralytics.utils.downloads import GITHUB_ASSET_STEMS from ultralytics.utils.torch_utils import smart_inference_mode -# Map head to model, trainer, validator, and predictor classes -TASK_MAP = { - 'classify': [ - ClassificationModel, yolo.classify.ClassificationTrainer, yolo.classify.ClassificationValidator, - yolo.classify.ClassificationPredictor], - 'detect': - [DetectionModel, yolo.detect.DetectionTrainer, yolo.detect.DetectionValidator, yolo.detect.DetectionPredictor], - 'segment': [ - SegmentationModel, yolo.segment.SegmentationTrainer, yolo.segment.SegmentationValidator, - yolo.segment.SegmentationPredictor], - 'pose': [PoseModel, yolo.pose.PoseTrainer, yolo.pose.PoseValidator, yolo.pose.PosePredictor]} - -class YOLO: +class Model: """ - YOLO (You Only Look Once) object detection model. + A base model class to unify apis for all the models. Args: model (str, Path): Path to the model file to load or create. @@ -81,13 +68,13 @@ class YOLO: self.predictor = None # reuse predictor self.model = None # model object self.trainer = None # trainer object - self.task = None # task type self.ckpt = None # if loaded from *.pt self.cfg = None # if loaded from *.yaml self.ckpt_path = None self.overrides = {} # overrides for trainer object self.metrics = None # validation/training metrics self.session = None # HUB session + self.task = task # task type model = str(model).strip() # strip spaces # Check if Ultralytics HUB model from https://hub.ultralytics.com @@ -109,11 +96,6 @@ class YOLO: """Calls the 'predict' function with given arguments to perform object detection.""" return self.predict(source, stream, **kwargs) - def __getattr__(self, attr): - """Raises error if object has no requested attribute.""" - name = self.__class__.__name__ - raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") - @staticmethod def is_hub_model(model): """Check if the provided model is a HUB model.""" @@ -122,19 +104,21 @@ class YOLO: [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID - def _new(self, cfg: str, task=None, verbose=True): + def _new(self, cfg: str, task=None, model=None, verbose=True): """ Initializes a new model and infers the task type from the model definitions. Args: cfg (str): model configuration file task (str | None): model task + model (BaseModel): Customized model. verbose (bool): display model info on load """ cfg_dict = yaml_model_load(cfg) self.cfg = cfg self.task = task or guess_model_task(cfg_dict) - self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model + model = model or self.smart_load('model') + self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model self.overrides['model'] = self.cfg # Below added to allow export from yamls @@ -217,7 +201,7 @@ class YOLO: self.model.fuse() @smart_inference_mode() - def predict(self, source=None, stream=False, **kwargs): + def predict(self, source=None, stream=False, predictor=None, **kwargs): """ Perform prediction using the YOLO model. @@ -225,6 +209,7 @@ class YOLO: source (str | int | PIL | np.ndarray): The source of the image to make predictions on. Accepts all source types accepted by the YOLO model. stream (bool): Whether to stream the predictions or not. Defaults to False. + predictor (BasePredictor): Customized predictor. **kwargs : Additional keyword arguments passed to the predictor. Check the 'configuration' section in the documentation for all available options. @@ -236,6 +221,8 @@ class YOLO: LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any( x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track')) + # Check prompts for SAM/FastSAM + prompts = kwargs.pop('prompts', None) overrides = self.overrides.copy() overrides['conf'] = 0.25 overrides.update(kwargs) # prefer kwargs @@ -245,12 +232,16 @@ class YOLO: overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python if not self.predictor: self.task = overrides.get('task') or self.task - self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks) + predictor = predictor or self.smart_load('predictor') + self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks) self.predictor.setup_model(model=self.model, verbose=is_cli) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, overrides) if 'project' in overrides or 'name' in overrides: self.predictor.save_dir = self.predictor.get_save_dir() + # Set prompts for SAM/FastSAM + if len and hasattr(self.predictor, 'set_prompts'): + self.predictor.set_prompts(prompts) return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) def track(self, source=None, stream=False, persist=False, **kwargs): @@ -277,12 +268,13 @@ class YOLO: return self.predict(source=source, stream=stream, **kwargs) @smart_inference_mode() - def val(self, data=None, **kwargs): + def val(self, data=None, validator=None, **kwargs): """ Validate a model on a given dataset. Args: data (str): The dataset to validate on. Accepts all formats accepted by yolo + validator (BaseValidator): Customized validator. **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs """ overrides = self.overrides.copy() @@ -295,11 +287,12 @@ class YOLO: self.task = args.task else: args.task = self.task + validator = validator or self.smart_load('validator') if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)): args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed args.imgsz = check_imgsz(args.imgsz, max_dim=1) - validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks) + validator = validator(args=args, _callbacks=self.callbacks) validator(model=self.model) self.metrics = validator.metrics @@ -349,11 +342,12 @@ class YOLO: args.task = self.task return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model) - def train(self, **kwargs): + def train(self, trainer=None, **kwargs): """ Trains the model on a given dataset. Args: + trainer (BaseTrainer, optional): Customized trainer. **kwargs (Any): Any number of arguments representing the training configuration. """ self._check_is_pytorch_model() @@ -373,7 +367,8 @@ class YOLO: if overrides.get('resume'): overrides['resume'] = self.ckpt_path self.task = overrides.get('task') or self.task - self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks) + trainer = trainer or self.smart_load('trainer') + self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks) if not overrides.get('resume'): # manually set model only if not resuming self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) self.model = self.trainer.model @@ -442,3 +437,27 @@ class YOLO: """Reset all registered callbacks.""" for event in callbacks.default_callbacks.keys(): self.callbacks[event] = [callbacks.default_callbacks[event][0]] + + def __getattr__(self, attr): + """Raises error if object has no requested attribute.""" + name = self.__class__.__name__ + raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + + def smart_load(self, key): + """Load model/trainer/validator/predictor.""" + try: + return self.task_map[self.task][key] + except Exception: + name = self.__class__.__name__ + mode = inspect.stack()[1][3] # get the function name. + raise NotImplementedError( + f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.') + + @property + def task_map(self): + """Map head to model, trainer, validator, and predictor classes + + Returns: + task_map (dict) + """ + raise NotImplementedError('Please provide task map for your model!') diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py index cca622266..e96f893e9 100644 --- a/ultralytics/models/__init__.py +++ b/ultralytics/models/__init__.py @@ -1,4 +1,7 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + from .rtdetr import RTDETR from .sam import SAM +from .yolo import YOLO -__all__ = 'RTDETR', 'SAM' # allow simpler import +__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import diff --git a/ultralytics/models/fastsam/model.py b/ultralytics/models/fastsam/model.py index 96ebc30d8..6cfedc4ae 100644 --- a/ultralytics/models/fastsam/model.py +++ b/ultralytics/models/fastsam/model.py @@ -1,111 +1,31 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -""" -FastSAM model interface. -Usage - Predict: - from ultralytics import FastSAM +from pathlib import Path - model = FastSAM('last.pt') - results = model.predict('ultralytics/assets/bus.jpg') -""" - -from ultralytics.cfg import get_cfg -from ultralytics.engine.exporter import Exporter -from ultralytics.engine.model import YOLO -from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir -from ultralytics.utils.checks import check_imgsz -from ultralytics.utils.torch_utils import model_info, smart_inference_mode +from ultralytics.engine.model import Model from .predict import FastSAMPredictor +from .val import FastSAMValidator + +class FastSAM(Model): + """ + FastSAM model interface. -class FastSAM(YOLO): + Usage - Predict: + from ultralytics import FastSAM + + model = FastSAM('last.pt') + results = model.predict('ultralytics/assets/bus.jpg') + """ def __init__(self, model='FastSAM-x.pt'): """Call the __init__ method of the parent class (YOLO) with the updated default model""" if model == 'FastSAM.pt': model = 'FastSAM-x.pt' - super().__init__(model=model) - # any additional initialization code for FastSAM - - @smart_inference_mode() - def predict(self, source=None, stream=False, **kwargs): - """ - Perform prediction using the YOLO model. - - Args: - source (str | int | PIL | np.ndarray): The source of the image to make predictions on. - Accepts all source types accepted by the YOLO model. - stream (bool): Whether to stream the predictions or not. Defaults to False. - **kwargs : Additional keyword arguments passed to the predictor. - Check the 'configuration' section in the documentation for all available options. - - Returns: - (List[ultralytics.engine.results.Results]): The prediction results. - """ - if source is None: - source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' - LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") - overrides = self.overrides.copy() - overrides['conf'] = 0.25 - overrides.update(kwargs) # prefer kwargs - overrides['mode'] = kwargs.get('mode', 'predict') - assert overrides['mode'] in ['track', 'predict'] - overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python - self.predictor = FastSAMPredictor(overrides=overrides) - self.predictor.setup_model(model=self.model, verbose=False) - - return self.predictor(source, stream=stream) - - def train(self, **kwargs): - """Function trains models but raises an error as FastSAM models do not support training.""" - raise NotImplementedError("FastSAM models don't support training") - - def val(self, **kwargs): - """Run validation given dataset.""" - overrides = dict(task='segment', mode='val') - overrides.update(kwargs) # prefer kwargs - args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) - args.imgsz = check_imgsz(args.imgsz, max_dim=1) - validator = FastSAM(args=args) - validator(model=self.model) - self.metrics = validator.metrics - return validator.metrics - - @smart_inference_mode() - def export(self, **kwargs): - """ - Export model. - - Args: - **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs - """ - overrides = dict(task='detect') - overrides.update(kwargs) - overrides['mode'] = 'export' - args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) - args.task = self.task - if args.imgsz == DEFAULT_CFG.imgsz: - args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed - if args.batch == DEFAULT_CFG.batch: - args.batch = 1 # default to 1 if not modified - return Exporter(overrides=args)(model=self.model) - - def info(self, detailed=False, verbose=True): - """ - Logs model info. - - Args: - detailed (bool): Show detailed information about model. - verbose (bool): Controls verbosity. - """ - return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) - - def __call__(self, source=None, stream=False, **kwargs): - """Calls the 'predict' function with given arguments to perform object detection.""" - return self.predict(source, stream, **kwargs) + assert Path(model).suffix != '.yaml', 'FastSAM models only support pre-trained models.' + super().__init__(model=model, task='segment') - def __getattr__(self, attr): - """Raises error if object has no requested attribute.""" - name = self.__class__.__name__ - raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + @property + def task_map(self): + return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}} diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index 1fece728a..518a7c81f 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -13,105 +13,36 @@ from pathlib import Path import torch -from ultralytics.cfg import get_cfg -from ultralytics.engine.exporter import Exporter -from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir -from ultralytics.utils.checks import check_imgsz +from ultralytics.engine.model import Model from ultralytics.utils.torch_utils import model_info, smart_inference_mode from .predict import NASPredictor from .val import NASValidator -class NAS: +class NAS(Model): def __init__(self, model='yolo_nas_s.pt') -> None: + assert Path(model).suffix != '.yaml', 'YOLO-NAS models only support pre-trained models.' + super().__init__(model, task='detect') + + @smart_inference_mode() + def _load(self, weights: str, task: str): # Load or create new NAS model import super_gradients - - self.predictor = None - suffix = Path(model).suffix + suffix = Path(weights).suffix if suffix == '.pt': - self._load(model) + self.model = torch.load(weights) elif suffix == '': - self.model = super_gradients.training.models.get(model, pretrained_weights='coco') - self.task = 'detect' - self.model.args = DEFAULT_CFG_DICT # attach args to model - + self.model = super_gradients.training.models.get(weights, pretrained_weights='coco') # Standardize model self.model.fuse = lambda verbose=True: self.model self.model.stride = torch.tensor([32]) self.model.names = dict(enumerate(self.model._class_names)) self.model.is_fused = lambda: False # for info() self.model.yaml = {} # for info() - self.model.pt_path = model # for export() + self.model.pt_path = weights # for export() self.model.task = 'detect' # for export() - self.info() - - @smart_inference_mode() - def _load(self, weights: str): - self.model = torch.load(weights) - - @smart_inference_mode() - def predict(self, source=None, stream=False, **kwargs): - """ - Perform prediction using the YOLO model. - - Args: - source (str | int | PIL | np.ndarray): The source of the image to make predictions on. - Accepts all source types accepted by the YOLO model. - stream (bool): Whether to stream the predictions or not. Defaults to False. - **kwargs : Additional keyword arguments passed to the predictor. - Check the 'configuration' section in the documentation for all available options. - - Returns: - (List[ultralytics.engine.results.Results]): The prediction results. - """ - if source is None: - source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' - LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") - overrides = dict(conf=0.25, task='detect', mode='predict') - overrides.update(kwargs) # prefer kwargs - if not self.predictor: - self.predictor = NASPredictor(overrides=overrides) - self.predictor.setup_model(model=self.model) - else: # only update args if predictor is already setup - self.predictor.args = get_cfg(self.predictor.args, overrides) - return self.predictor(source, stream=stream) - - def train(self, **kwargs): - """Function trains models but raises an error as NAS models do not support training.""" - raise NotImplementedError("NAS models don't support training") - - def val(self, **kwargs): - """Run validation given dataset.""" - overrides = dict(task='detect', mode='val') - overrides.update(kwargs) # prefer kwargs - args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) - args.imgsz = check_imgsz(args.imgsz, max_dim=1) - validator = NASValidator(args=args) - validator(model=self.model) - self.metrics = validator.metrics - return validator.metrics - - @smart_inference_mode() - def export(self, **kwargs): - """ - Export model. - - Args: - **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs - """ - overrides = dict(task='detect') - overrides.update(kwargs) - overrides['mode'] = 'export' - args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) - args.task = self.task - if args.imgsz == DEFAULT_CFG.imgsz: - args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed - if args.batch == DEFAULT_CFG.batch: - args.batch = 1 # default to 1 if not modified - return Exporter(overrides=args)(model=self.model) def info(self, detailed=False, verbose=True): """ @@ -123,11 +54,6 @@ class NAS: """ return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) - def __call__(self, source=None, stream=False, **kwargs): - """Calls the 'predict' function with given arguments to perform object detection.""" - return self.predict(source, stream, **kwargs) - - def __getattr__(self, attr): - """Raises error if object has no requested attribute.""" - name = self.__class__.__name__ - raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + @property + def task_map(self): + return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}} diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py index 19c903eda..5612a0402 100644 --- a/ultralytics/models/rtdetr/model.py +++ b/ultralytics/models/rtdetr/model.py @@ -2,172 +2,29 @@ """ RT-DETR model interface """ - -from pathlib import Path - -import torch.nn as nn - -from ultralytics.cfg import get_cfg -from ultralytics.engine.exporter import Exporter -from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load -from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir -from ultralytics.utils.checks import check_imgsz -from ultralytics.utils.torch_utils import model_info, smart_inference_mode +from ultralytics.engine.model import Model +from ultralytics.nn.tasks import RTDETRDetectionModel from .predict import RTDETRPredictor from .train import RTDETRTrainer from .val import RTDETRValidator -class RTDETR: +class RTDETR(Model): + """ + RTDETR model interface. + """ def __init__(self, model='rtdetr-l.pt') -> None: if model and not model.endswith('.pt') and not model.endswith('.yaml'): raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.') - # Load or create new YOLO model - self.predictor = None - self.ckpt = None - suffix = Path(model).suffix - if suffix == '.yaml': - self._new(model) - else: - self._load(model) - - def _new(self, cfg: str, verbose=True): - cfg_dict = yaml_model_load(cfg) - self.cfg = cfg - self.task = 'detect' - self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model - - # Below added to allow export from YAMLs - self.model.args = DEFAULT_CFG_DICT # attach args to model - self.model.task = self.task - - @smart_inference_mode() - def _load(self, weights: str): - self.model, self.ckpt = attempt_load_one_weight(weights) - self.model.args = DEFAULT_CFG_DICT # attach args to model - self.task = self.model.args['task'] - - @smart_inference_mode() - def load(self, weights='yolov8n.pt'): - """ - Transfers parameters with matching names and shapes from 'weights' to model. - """ - if isinstance(weights, (str, Path)): - weights, self.ckpt = attempt_load_one_weight(weights) - self.model.load(weights) - return self - - @smart_inference_mode() - def predict(self, source=None, stream=False, **kwargs): - """ - Perform prediction using the YOLO model. - - Args: - source (str | int | PIL | np.ndarray): The source of the image to make predictions on. - Accepts all source types accepted by the YOLO model. - stream (bool): Whether to stream the predictions or not. Defaults to False. - **kwargs : Additional keyword arguments passed to the predictor. - Check the 'configuration' section in the documentation for all available options. - - Returns: - (List[ultralytics.engine.results.Results]): The prediction results. - """ - if source is None: - source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' - LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") - overrides = dict(conf=0.25, task='detect', mode='predict') - overrides.update(kwargs) # prefer kwargs - if not self.predictor: - self.predictor = RTDETRPredictor(overrides=overrides) - self.predictor.setup_model(model=self.model) - else: # only update args if predictor is already setup - self.predictor.args = get_cfg(self.predictor.args, overrides) - return self.predictor(source, stream=stream) - - def train(self, **kwargs): - """ - Trains the model on a given dataset. - - Args: - **kwargs (Any): Any number of arguments representing the training configuration. - """ - overrides = dict(task='detect', mode='train') - overrides.update(kwargs) - overrides['deterministic'] = False - if not overrides.get('data'): - raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'") - if overrides.get('resume'): - overrides['resume'] = self.ckpt_path - self.task = overrides.get('task') or self.task - self.trainer = RTDETRTrainer(overrides=overrides) - if not overrides.get('resume'): # manually set model only if not resuming - self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml) - self.model = self.trainer.model - self.trainer.train() - # Update model and cfg after training - if RANK in (-1, 0): - self.model, _ = attempt_load_one_weight(str(self.trainer.best)) - self.overrides = self.model.args - self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP - - def val(self, **kwargs): - """Run validation given dataset.""" - overrides = dict(task='detect', mode='val') - overrides.update(kwargs) # prefer kwargs - args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) - args.imgsz = check_imgsz(args.imgsz, max_dim=1) - validator = RTDETRValidator(args=args) - validator(model=self.model) - self.metrics = validator.metrics - return validator.metrics - - def info(self, verbose=True): - """Get model info""" - return model_info(self.model, verbose=verbose) - - def _check_is_pytorch_model(self): - """ - Raises TypeError is model is not a PyTorch model - """ - pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt' - pt_module = isinstance(self.model, nn.Module) - if not (pt_module or pt_str): - raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. " - f'PyTorch models can be used to train, val, predict and export, i.e. ' - f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only " - f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.") - - def fuse(self): - """Fuse PyTorch Conv2d and BatchNorm2d layers.""" - self._check_is_pytorch_model() - self.model.fuse() - - @smart_inference_mode() - def export(self, **kwargs): - """ - Export model. - - Args: - **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs - """ - overrides = dict(task='detect') - overrides.update(kwargs) - overrides['mode'] = 'export' - args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) - args.task = self.task - if args.imgsz == DEFAULT_CFG.imgsz: - args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed - if args.batch == DEFAULT_CFG.batch: - args.batch = 1 # default to 1 if not modified - return Exporter(overrides=args)(model=self.model) - - def __call__(self, source=None, stream=False, **kwargs): - """Calls the 'predict' function with given arguments to perform object detection.""" - return self.predict(source, stream, **kwargs) - - def __getattr__(self, attr): - """Raises error if object has no requested attribute.""" - name = self.__class__.__name__ - raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") + super().__init__(model=model, task='detect') + + @property + def task_map(self): + return { + 'detect': { + 'predictor': RTDETRPredictor, + 'validator': RTDETRValidator, + 'trainer': RTDETRTrainer, + 'model': RTDETRDetectionModel}} diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py index 4d7025143..b941f7bf7 100644 --- a/ultralytics/models/sam/model.py +++ b/ultralytics/models/sam/model.py @@ -3,51 +3,38 @@ SAM model interface """ -from ultralytics.cfg import get_cfg +from ultralytics.engine.model import Model from ultralytics.utils.torch_utils import model_info from .build import build_sam from .predict import Predictor -class SAM: +class SAM(Model): + """ + SAM model interface. + """ def __init__(self, model='sam_b.pt') -> None: if model and not model.endswith('.pt') and not model.endswith('.pth'): # Should raise AssertionError instead? raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint') - self.model = build_sam(model) - self.task = 'segment' # required - self.predictor = None # reuse predictor + super().__init__(model=model, task='segment') + + def _load(self, weights: str, task=None): + self.model = build_sam(weights) def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs): """Predicts and returns segmentation masks for given image or video source.""" overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024) - overrides.update(kwargs) # prefer kwargs - if not self.predictor: - self.predictor = Predictor(overrides=overrides) - self.predictor.setup_model(model=self.model) - else: # only update args if predictor is already setup - self.predictor.args = get_cfg(self.predictor.args, overrides) - return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels) - - def train(self, **kwargs): - """Function trains models but raises an error as SAM models do not support training.""" - raise NotImplementedError("SAM models don't support training") - - def val(self, **kwargs): - """Run validation given dataset.""" - raise NotImplementedError("SAM models don't support validation") + kwargs.update(overrides) + prompts = dict(bboxes=bboxes, points=points, labels=labels) + super().predict(source, stream, prompts=prompts, **kwargs) def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs): """Calls the 'predict' function with given arguments to perform object detection.""" return self.predict(source, stream, bboxes, points, labels, **kwargs) - def __getattr__(self, attr): - """Raises error if object has no requested attribute.""" - name = self.__class__.__name__ - raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") - def info(self, detailed=False, verbose=True): """ Logs model info. @@ -57,3 +44,7 @@ class SAM: verbose (bool): Controls verbosity. """ return model_info(self.model, detailed=detailed, verbose=verbose) + + @property + def task_map(self): + return {'segment': {'predictor': Predictor}} diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py index 8f98d58b1..ebd072a7c 100644 --- a/ultralytics/models/sam/predict.py +++ b/ultralytics/models/sam/predict.py @@ -28,6 +28,8 @@ class Predictor(BasePredictor): # Args for set_image self.im = None self.features = None + # Args for set_prompts + self.prompts = {} # Args for segment everything self.segment_all = False @@ -92,6 +94,10 @@ class Predictor(BasePredictor): of masks and H=W=256. These low resolution logits can be passed to a subsequent iteration as mask input. """ + # Get prompts from self.prompts first + bboxes = self.prompts.pop('bboxes', bboxes) + points = self.prompts.pop('points', points) + masks = self.prompts.pop('masks', masks) if all(i is None for i in [bboxes, points, masks]): return self.generate(im, *args, **kwargs) return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output) @@ -348,6 +354,10 @@ class Predictor(BasePredictor): self.im = im break + def set_prompts(self, prompts): + """Set prompts in advance.""" + self.prompts = prompts + def reset_image(self): self.im = None self.features = None diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py index a88c60b88..c66e37627 100644 --- a/ultralytics/models/yolo/__init__.py +++ b/ultralytics/models/yolo/__init__.py @@ -2,4 +2,6 @@ from ultralytics.models.yolo import classify, detect, pose, segment -__all__ = 'classify', 'segment', 'detect', 'pose' +from .model import YOLO + +__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO' diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py index e697a059b..b32156452 100644 --- a/ultralytics/models/yolo/detect/train.py +++ b/ultralytics/models/yolo/detect/train.py @@ -1,4 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license + from copy import copy import numpy as np diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py new file mode 100644 index 000000000..b85d46bdb --- /dev/null +++ b/ultralytics/models/yolo/model.py @@ -0,0 +1,36 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from ultralytics.engine.model import Model +from ultralytics.models import yolo # noqa +from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel + + +class YOLO(Model): + """ + YOLO (You Only Look Once) object detection model. + """ + + @property + def task_map(self): + """Map head to model, trainer, validator, and predictor classes""" + return { + 'classify': { + 'model': ClassificationModel, + 'trainer': yolo.classify.ClassificationTrainer, + 'validator': yolo.classify.ClassificationValidator, + 'predictor': yolo.classify.ClassificationPredictor, }, + 'detect': { + 'model': DetectionModel, + 'trainer': yolo.detect.DetectionTrainer, + 'validator': yolo.detect.DetectionValidator, + 'predictor': yolo.detect.DetectionPredictor, }, + 'segment': { + 'model': SegmentationModel, + 'trainer': yolo.segment.SegmentationTrainer, + 'validator': yolo.segment.SegmentationValidator, + 'predictor': yolo.segment.SegmentationPredictor, }, + 'pose': { + 'model': PoseModel, + 'trainer': yolo.pose.PoseTrainer, + 'validator': yolo.pose.PoseValidator, + 'predictor': yolo.pose.PosePredictor, }, } diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py index e239c2636..89d5cb223 100644 --- a/ultralytics/models/yolo/segment/train.py +++ b/ultralytics/models/yolo/segment/train.py @@ -1,4 +1,5 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license + from copy import copy from ultralytics.models import yolo