diff --git a/docs/reference/engine/model.md b/docs/reference/engine/model.md
index 6ce7bc3a3..f6e468a33 100644
--- a/docs/reference/engine/model.md
+++ b/docs/reference/engine/model.md
@@ -3,7 +3,7 @@ description: Explore the detailed guide on using the Ultralytics YOLO Engine Mod
keywords: Ultralytics, YOLO, engine model, documentation, guide, implementation, training, evaluation
---
-## YOLO
+## Model
---
-### ::: ultralytics.engine.model.YOLO
+### ::: ultralytics.engine.model.Model
diff --git a/docs/reference/models/yolo/model.md b/docs/reference/models/yolo/model.md
new file mode 100644
index 000000000..5efdec8a8
--- /dev/null
+++ b/docs/reference/models/yolo/model.md
@@ -0,0 +1,9 @@
+---
+description: Discover the Ultralytics YOLO model class. Learn advanced techniques, tips, and tricks for training.
+keywords: Ultralytics YOLO, YOLO, YOLO model, Model Training, Machine Learning, Deep Learning, Computer Vision
+---
+
+## YOLO
+---
+### ::: ultralytics.models.yolo.model.YOLO
+
diff --git a/mkdocs.yml b/mkdocs.yml
index 75eaf64e6..7faccefc3 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -317,6 +317,7 @@ nav:
- predict: reference/models/yolo/detect/predict.md
- train: reference/models/yolo/detect/train.md
- val: reference/models/yolo/detect/val.md
+ - model: reference/models/yolo/model.md
- pose:
- predict: reference/models/yolo/pose/predict.md
- train: reference/models/yolo/pose/train.md
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index a0aaeca32..d1648f707 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,10 +1,9 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.142'
+__version__ = '8.0.143'
-from ultralytics.engine.model import YOLO
from ultralytics.hub import start
-from ultralytics.models import RTDETR, SAM
+from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
from ultralytics.models.nas import NAS
from ultralytics.utils import SETTINGS as settings
diff --git a/ultralytics/engine/model.py b/ultralytics/engine/model.py
index 9ca82f92f..29b27a54a 100644
--- a/ultralytics/engine/model.py
+++ b/ultralytics/engine/model.py
@@ -1,36 +1,23 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
+import inspect
import sys
from pathlib import Path
from typing import Union
from ultralytics.cfg import get_cfg
from ultralytics.engine.exporter import Exporter
-from ultralytics.models import yolo # noqa
-from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, PoseModel, SegmentationModel,
- attempt_load_one_weight, guess_model_task, nn, yaml_model_load)
+from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, yaml_load)
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.utils.torch_utils import smart_inference_mode
-# Map head to model, trainer, validator, and predictor classes
-TASK_MAP = {
- 'classify': [
- ClassificationModel, yolo.classify.ClassificationTrainer, yolo.classify.ClassificationValidator,
- yolo.classify.ClassificationPredictor],
- 'detect':
- [DetectionModel, yolo.detect.DetectionTrainer, yolo.detect.DetectionValidator, yolo.detect.DetectionPredictor],
- 'segment': [
- SegmentationModel, yolo.segment.SegmentationTrainer, yolo.segment.SegmentationValidator,
- yolo.segment.SegmentationPredictor],
- 'pose': [PoseModel, yolo.pose.PoseTrainer, yolo.pose.PoseValidator, yolo.pose.PosePredictor]}
-
-class YOLO:
+class Model:
"""
- YOLO (You Only Look Once) object detection model.
+ A base model class to unify APIs for all the models.
Args:
model (str, Path): Path to the model file to load or create.
@@ -81,13 +68,13 @@ class YOLO:
self.predictor = None # reuse predictor
self.model = None # model object
self.trainer = None # trainer object
- self.task = None # task type
self.ckpt = None # if loaded from *.pt
self.cfg = None # if loaded from *.yaml
self.ckpt_path = None
self.overrides = {} # overrides for trainer object
self.metrics = None # validation/training metrics
self.session = None # HUB session
+ self.task = task # task type
model = str(model).strip() # strip spaces
# Check if Ultralytics HUB model from https://hub.ultralytics.com
@@ -109,11 +96,6 @@ class YOLO:
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
- def __getattr__(self, attr):
- """Raises error if object has no requested attribute."""
- name = self.__class__.__name__
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
-
@staticmethod
def is_hub_model(model):
"""Check if the provided model is a HUB model."""
@@ -122,19 +104,21 @@ class YOLO:
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
- def _new(self, cfg: str, task=None, verbose=True):
+ def _new(self, cfg: str, task=None, model=None, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions.
Args:
cfg (str): model configuration file
task (str | None): model task
+ model (BaseModel): Customized model.
verbose (bool): display model info on load
"""
cfg_dict = yaml_model_load(cfg)
self.cfg = cfg
self.task = task or guess_model_task(cfg_dict)
- self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
+ model = model or self.smart_load('model')
+ self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
self.overrides['model'] = self.cfg
# Below added to allow export from yamls
@@ -217,7 +201,7 @@ class YOLO:
self.model.fuse()
@smart_inference_mode()
- def predict(self, source=None, stream=False, **kwargs):
+ def predict(self, source=None, stream=False, predictor=None, **kwargs):
"""
Perform prediction using the YOLO model.
@@ -225,6 +209,7 @@ class YOLO:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
+ predictor (BasePredictor): Customized predictor.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
@@ -236,6 +221,8 @@ class YOLO:
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
+ # Check prompts for SAM/FastSAM
+ prompts = kwargs.pop('prompts', None)
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
@@ -245,12 +232,16 @@ class YOLO:
overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
if not self.predictor:
self.task = overrides.get('task') or self.task
- self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
+ predictor = predictor or self.smart_load('predictor')
+ self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
self.predictor.setup_model(model=self.model, verbose=is_cli)
else: # only update args if predictor is already setup
self.predictor.args = get_cfg(self.predictor.args, overrides)
if 'project' in overrides or 'name' in overrides:
self.predictor.save_dir = self.predictor.get_save_dir()
+        # Set prompts for SAM/FastSAM, only when the caller supplied some and the predictor accepts them
+        if prompts and hasattr(self.predictor, 'set_prompts'):
+            self.predictor.set_prompts(prompts)
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
def track(self, source=None, stream=False, persist=False, **kwargs):
@@ -277,12 +268,13 @@ class YOLO:
return self.predict(source=source, stream=stream, **kwargs)
@smart_inference_mode()
- def val(self, data=None, **kwargs):
+ def val(self, data=None, validator=None, **kwargs):
"""
Validate a model on a given dataset.
Args:
data (str): The dataset to validate on. Accepts all formats accepted by yolo
+ validator (BaseValidator): Customized validator.
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
overrides = self.overrides.copy()
@@ -295,11 +287,12 @@ class YOLO:
self.task = args.task
else:
args.task = self.task
+ validator = validator or self.smart_load('validator')
if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
- validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
+ validator = validator(args=args, _callbacks=self.callbacks)
validator(model=self.model)
self.metrics = validator.metrics
@@ -349,11 +342,12 @@ class YOLO:
args.task = self.task
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
- def train(self, **kwargs):
+ def train(self, trainer=None, **kwargs):
"""
Trains the model on a given dataset.
Args:
+ trainer (BaseTrainer, optional): Customized trainer.
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
@@ -373,7 +367,8 @@ class YOLO:
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.task = overrides.get('task') or self.task
- self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
+ trainer = trainer or self.smart_load('trainer')
+ self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
@@ -442,3 +437,27 @@ class YOLO:
"""Reset all registered callbacks."""
for event in callbacks.default_callbacks.keys():
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
+
+    def __getattr__(self, attr):
+        """Fail the lookup of any undefined attribute with an AttributeError that embeds the class docstring."""
+        cls = type(self).__name__
+        raise AttributeError(f"'{cls}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+
+    def smart_load(self, key):
+        """Return the class registered under `key` for the current task; raise NotImplementedError if absent."""
+        try:
+            return self.task_map[self.task][key]
+        except Exception as e:  # missing task/key, or task_map itself not implemented
+            name = self.__class__.__name__
+            mode = inspect.stack(context=0)[1].function  # caller's name, without loading source lines
+            raise NotImplementedError(
+                f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.') from e
+
+    @property
+    def task_map(self):
+        """
+        Mapping from task name to its model/trainer/validator/predictor classes; subclasses must override it.
+        The base implementation always raises NotImplementedError.
+        """
+        reason = 'Please provide task map for your model!'
+        raise NotImplementedError(reason)
diff --git a/ultralytics/models/__init__.py b/ultralytics/models/__init__.py
index cca622266..e96f893e9 100644
--- a/ultralytics/models/__init__.py
+++ b/ultralytics/models/__init__.py
@@ -1,4 +1,7 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
from .rtdetr import RTDETR
from .sam import SAM
+from .yolo import YOLO
-__all__ = 'RTDETR', 'SAM' # allow simpler import
+__all__ = 'YOLO', 'RTDETR', 'SAM' # allow simpler import
diff --git a/ultralytics/models/fastsam/model.py b/ultralytics/models/fastsam/model.py
index 96ebc30d8..6cfedc4ae 100644
--- a/ultralytics/models/fastsam/model.py
+++ b/ultralytics/models/fastsam/model.py
@@ -1,111 +1,31 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-"""
-FastSAM model interface.
-Usage - Predict:
- from ultralytics import FastSAM
+from pathlib import Path
- model = FastSAM('last.pt')
- results = model.predict('ultralytics/assets/bus.jpg')
-"""
-
-from ultralytics.cfg import get_cfg
-from ultralytics.engine.exporter import Exporter
-from ultralytics.engine.model import YOLO
-from ultralytics.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir
-from ultralytics.utils.checks import check_imgsz
-from ultralytics.utils.torch_utils import model_info, smart_inference_mode
+from ultralytics.engine.model import Model
from .predict import FastSAMPredictor
+from .val import FastSAMValidator
+
+class FastSAM(Model):
+ """
+ FastSAM model interface.
-class FastSAM(YOLO):
+ Usage - Predict:
+ from ultralytics import FastSAM
+
+ model = FastSAM('last.pt')
+ results = model.predict('ultralytics/assets/bus.jpg')
+ """
def __init__(self, model='FastSAM-x.pt'):
"""Call the __init__ method of the parent class (YOLO) with the updated default model"""
if model == 'FastSAM.pt':
model = 'FastSAM-x.pt'
- super().__init__(model=model)
- # any additional initialization code for FastSAM
-
- @smart_inference_mode()
- def predict(self, source=None, stream=False, **kwargs):
- """
- Perform prediction using the YOLO model.
-
- Args:
- source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
- Accepts all source types accepted by the YOLO model.
- stream (bool): Whether to stream the predictions or not. Defaults to False.
- **kwargs : Additional keyword arguments passed to the predictor.
- Check the 'configuration' section in the documentation for all available options.
-
- Returns:
- (List[ultralytics.engine.results.Results]): The prediction results.
- """
- if source is None:
- source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
- LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
- overrides = self.overrides.copy()
- overrides['conf'] = 0.25
- overrides.update(kwargs) # prefer kwargs
- overrides['mode'] = kwargs.get('mode', 'predict')
- assert overrides['mode'] in ['track', 'predict']
- overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
- self.predictor = FastSAMPredictor(overrides=overrides)
- self.predictor.setup_model(model=self.model, verbose=False)
-
- return self.predictor(source, stream=stream)
-
- def train(self, **kwargs):
- """Function trains models but raises an error as FastSAM models do not support training."""
- raise NotImplementedError("FastSAM models don't support training")
-
- def val(self, **kwargs):
- """Run validation given dataset."""
- overrides = dict(task='segment', mode='val')
- overrides.update(kwargs) # prefer kwargs
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
- args.imgsz = check_imgsz(args.imgsz, max_dim=1)
- validator = FastSAM(args=args)
- validator(model=self.model)
- self.metrics = validator.metrics
- return validator.metrics
-
- @smart_inference_mode()
- def export(self, **kwargs):
- """
- Export model.
-
- Args:
- **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
- """
- overrides = dict(task='detect')
- overrides.update(kwargs)
- overrides['mode'] = 'export'
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
- args.task = self.task
- if args.imgsz == DEFAULT_CFG.imgsz:
- args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
- if args.batch == DEFAULT_CFG.batch:
- args.batch = 1 # default to 1 if not modified
- return Exporter(overrides=args)(model=self.model)
-
- def info(self, detailed=False, verbose=True):
- """
- Logs model info.
-
- Args:
- detailed (bool): Show detailed information about model.
- verbose (bool): Controls verbosity.
- """
- return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
-
- def __call__(self, source=None, stream=False, **kwargs):
- """Calls the 'predict' function with given arguments to perform object detection."""
- return self.predict(source, stream, **kwargs)
+ assert Path(model).suffix != '.yaml', 'FastSAM models only support pre-trained models.'
+ super().__init__(model=model, task='segment')
- def __getattr__(self, attr):
- """Raises error if object has no requested attribute."""
- name = self.__class__.__name__
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+    @property
+    def task_map(self):  # only the 'segment' task is mapped, with a predictor and a validator (no trainer/model)
+        return {'segment': {'predictor': FastSAMPredictor, 'validator': FastSAMValidator}}
diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py
index 1fece728a..518a7c81f 100644
--- a/ultralytics/models/nas/model.py
+++ b/ultralytics/models/nas/model.py
@@ -13,105 +13,36 @@ from pathlib import Path
import torch
-from ultralytics.cfg import get_cfg
-from ultralytics.engine.exporter import Exporter
-from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, ROOT, is_git_dir
-from ultralytics.utils.checks import check_imgsz
+from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
from .predict import NASPredictor
from .val import NASValidator
-class NAS:
+class NAS(Model):
def __init__(self, model='yolo_nas_s.pt') -> None:
+ assert Path(model).suffix != '.yaml', 'YOLO-NAS models only support pre-trained models.'
+ super().__init__(model, task='detect')
+
+ @smart_inference_mode()
+ def _load(self, weights: str, task: str):
# Load or create new NAS model
import super_gradients
-
- self.predictor = None
- suffix = Path(model).suffix
+ suffix = Path(weights).suffix
if suffix == '.pt':
- self._load(model)
+ self.model = torch.load(weights)
elif suffix == '':
- self.model = super_gradients.training.models.get(model, pretrained_weights='coco')
- self.task = 'detect'
- self.model.args = DEFAULT_CFG_DICT # attach args to model
-
+ self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
- self.model.pt_path = model # for export()
+ self.model.pt_path = weights # for export()
self.model.task = 'detect' # for export()
- self.info()
-
- @smart_inference_mode()
- def _load(self, weights: str):
- self.model = torch.load(weights)
-
- @smart_inference_mode()
- def predict(self, source=None, stream=False, **kwargs):
- """
- Perform prediction using the YOLO model.
-
- Args:
- source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
- Accepts all source types accepted by the YOLO model.
- stream (bool): Whether to stream the predictions or not. Defaults to False.
- **kwargs : Additional keyword arguments passed to the predictor.
- Check the 'configuration' section in the documentation for all available options.
-
- Returns:
- (List[ultralytics.engine.results.Results]): The prediction results.
- """
- if source is None:
- source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
- LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
- overrides = dict(conf=0.25, task='detect', mode='predict')
- overrides.update(kwargs) # prefer kwargs
- if not self.predictor:
- self.predictor = NASPredictor(overrides=overrides)
- self.predictor.setup_model(model=self.model)
- else: # only update args if predictor is already setup
- self.predictor.args = get_cfg(self.predictor.args, overrides)
- return self.predictor(source, stream=stream)
-
- def train(self, **kwargs):
- """Function trains models but raises an error as NAS models do not support training."""
- raise NotImplementedError("NAS models don't support training")
-
- def val(self, **kwargs):
- """Run validation given dataset."""
- overrides = dict(task='detect', mode='val')
- overrides.update(kwargs) # prefer kwargs
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
- args.imgsz = check_imgsz(args.imgsz, max_dim=1)
- validator = NASValidator(args=args)
- validator(model=self.model)
- self.metrics = validator.metrics
- return validator.metrics
-
- @smart_inference_mode()
- def export(self, **kwargs):
- """
- Export model.
-
- Args:
- **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
- """
- overrides = dict(task='detect')
- overrides.update(kwargs)
- overrides['mode'] = 'export'
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
- args.task = self.task
- if args.imgsz == DEFAULT_CFG.imgsz:
- args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
- if args.batch == DEFAULT_CFG.batch:
- args.batch = 1 # default to 1 if not modified
- return Exporter(overrides=args)(model=self.model)
def info(self, detailed=False, verbose=True):
"""
@@ -123,11 +54,6 @@ class NAS:
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
- def __call__(self, source=None, stream=False, **kwargs):
- """Calls the 'predict' function with given arguments to perform object detection."""
- return self.predict(source, stream, **kwargs)
-
- def __getattr__(self, attr):
- """Raises error if object has no requested attribute."""
- name = self.__class__.__name__
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+    @property
+    def task_map(self):  # only the 'detect' task is mapped, with a predictor and a validator (no trainer/model)
+        return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}
diff --git a/ultralytics/models/rtdetr/model.py b/ultralytics/models/rtdetr/model.py
index 19c903eda..5612a0402 100644
--- a/ultralytics/models/rtdetr/model.py
+++ b/ultralytics/models/rtdetr/model.py
@@ -2,172 +2,29 @@
"""
RT-DETR model interface
"""
-
-from pathlib import Path
-
-import torch.nn as nn
-
-from ultralytics.cfg import get_cfg
-from ultralytics.engine.exporter import Exporter
-from ultralytics.nn.tasks import RTDETRDetectionModel, attempt_load_one_weight, yaml_model_load
-from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, RANK, ROOT, is_git_dir
-from ultralytics.utils.checks import check_imgsz
-from ultralytics.utils.torch_utils import model_info, smart_inference_mode
+from ultralytics.engine.model import Model
+from ultralytics.nn.tasks import RTDETRDetectionModel
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator
-class RTDETR:
+class RTDETR(Model):
+ """
+ RTDETR model interface.
+ """
def __init__(self, model='rtdetr-l.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.yaml'):
raise NotImplementedError('RT-DETR only supports creating from pt file or yaml file.')
- # Load or create new YOLO model
- self.predictor = None
- self.ckpt = None
- suffix = Path(model).suffix
- if suffix == '.yaml':
- self._new(model)
- else:
- self._load(model)
-
- def _new(self, cfg: str, verbose=True):
- cfg_dict = yaml_model_load(cfg)
- self.cfg = cfg
- self.task = 'detect'
- self.model = RTDETRDetectionModel(cfg_dict, verbose=verbose) # build model
-
- # Below added to allow export from YAMLs
- self.model.args = DEFAULT_CFG_DICT # attach args to model
- self.model.task = self.task
-
- @smart_inference_mode()
- def _load(self, weights: str):
- self.model, self.ckpt = attempt_load_one_weight(weights)
- self.model.args = DEFAULT_CFG_DICT # attach args to model
- self.task = self.model.args['task']
-
- @smart_inference_mode()
- def load(self, weights='yolov8n.pt'):
- """
- Transfers parameters with matching names and shapes from 'weights' to model.
- """
- if isinstance(weights, (str, Path)):
- weights, self.ckpt = attempt_load_one_weight(weights)
- self.model.load(weights)
- return self
-
- @smart_inference_mode()
- def predict(self, source=None, stream=False, **kwargs):
- """
- Perform prediction using the YOLO model.
-
- Args:
- source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
- Accepts all source types accepted by the YOLO model.
- stream (bool): Whether to stream the predictions or not. Defaults to False.
- **kwargs : Additional keyword arguments passed to the predictor.
- Check the 'configuration' section in the documentation for all available options.
-
- Returns:
- (List[ultralytics.engine.results.Results]): The prediction results.
- """
- if source is None:
- source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
- LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
- overrides = dict(conf=0.25, task='detect', mode='predict')
- overrides.update(kwargs) # prefer kwargs
- if not self.predictor:
- self.predictor = RTDETRPredictor(overrides=overrides)
- self.predictor.setup_model(model=self.model)
- else: # only update args if predictor is already setup
- self.predictor.args = get_cfg(self.predictor.args, overrides)
- return self.predictor(source, stream=stream)
-
- def train(self, **kwargs):
- """
- Trains the model on a given dataset.
-
- Args:
- **kwargs (Any): Any number of arguments representing the training configuration.
- """
- overrides = dict(task='detect', mode='train')
- overrides.update(kwargs)
- overrides['deterministic'] = False
- if not overrides.get('data'):
- raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
- if overrides.get('resume'):
- overrides['resume'] = self.ckpt_path
- self.task = overrides.get('task') or self.task
- self.trainer = RTDETRTrainer(overrides=overrides)
- if not overrides.get('resume'): # manually set model only if not resuming
- self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
- self.model = self.trainer.model
- self.trainer.train()
- # Update model and cfg after training
- if RANK in (-1, 0):
- self.model, _ = attempt_load_one_weight(str(self.trainer.best))
- self.overrides = self.model.args
- self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
-
- def val(self, **kwargs):
- """Run validation given dataset."""
- overrides = dict(task='detect', mode='val')
- overrides.update(kwargs) # prefer kwargs
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
- args.imgsz = check_imgsz(args.imgsz, max_dim=1)
- validator = RTDETRValidator(args=args)
- validator(model=self.model)
- self.metrics = validator.metrics
- return validator.metrics
-
- def info(self, verbose=True):
- """Get model info"""
- return model_info(self.model, verbose=verbose)
-
- def _check_is_pytorch_model(self):
- """
- Raises TypeError is model is not a PyTorch model
- """
- pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
- pt_module = isinstance(self.model, nn.Module)
- if not (pt_module or pt_str):
- raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
- f'PyTorch models can be used to train, val, predict and export, i.e. '
- f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
- f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
-
- def fuse(self):
- """Fuse PyTorch Conv2d and BatchNorm2d layers."""
- self._check_is_pytorch_model()
- self.model.fuse()
-
- @smart_inference_mode()
- def export(self, **kwargs):
- """
- Export model.
-
- Args:
- **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
- """
- overrides = dict(task='detect')
- overrides.update(kwargs)
- overrides['mode'] = 'export'
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
- args.task = self.task
- if args.imgsz == DEFAULT_CFG.imgsz:
- args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
- if args.batch == DEFAULT_CFG.batch:
- args.batch = 1 # default to 1 if not modified
- return Exporter(overrides=args)(model=self.model)
-
- def __call__(self, source=None, stream=False, **kwargs):
- """Calls the 'predict' function with given arguments to perform object detection."""
- return self.predict(source, stream, **kwargs)
-
- def __getattr__(self, attr):
- """Raises error if object has no requested attribute."""
- name = self.__class__.__name__
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+ super().__init__(model=model, task='detect')
+
+    @property
+    def task_map(self):
+        """Map the 'detect' task to the RT-DETR predictor, validator, trainer and model classes."""
+        detect = dict(predictor=RTDETRPredictor,
+                      validator=RTDETRValidator,
+                      trainer=RTDETRTrainer,
+                      model=RTDETRDetectionModel)
+        return {'detect': detect}
diff --git a/ultralytics/models/sam/model.py b/ultralytics/models/sam/model.py
index 4d7025143..b941f7bf7 100644
--- a/ultralytics/models/sam/model.py
+++ b/ultralytics/models/sam/model.py
@@ -3,51 +3,38 @@
SAM model interface
"""
-from ultralytics.cfg import get_cfg
+from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor
-class SAM:
+class SAM(Model):
+ """
+ SAM model interface.
+ """
def __init__(self, model='sam_b.pt') -> None:
if model and not model.endswith('.pt') and not model.endswith('.pth'):
# Should raise AssertionError instead?
raise NotImplementedError('Segment anything prediction requires pre-trained checkpoint')
- self.model = build_sam(model)
- self.task = 'segment' # required
- self.predictor = None # reuse predictor
+ super().__init__(model=model, task='segment')
+
+    def _load(self, weights: str, task=None):  # builds the SAM model from `weights`; `task` is accepted but unused here
+        self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
- overrides.update(kwargs) # prefer kwargs
- if not self.predictor:
- self.predictor = Predictor(overrides=overrides)
- self.predictor.setup_model(model=self.model)
- else: # only update args if predictor is already setup
- self.predictor.args = get_cfg(self.predictor.args, overrides)
- return self.predictor(source, stream=stream, bboxes=bboxes, points=points, labels=labels)
-
- def train(self, **kwargs):
- """Function trains models but raises an error as SAM models do not support training."""
- raise NotImplementedError("SAM models don't support training")
-
- def val(self, **kwargs):
- """Run validation given dataset."""
- raise NotImplementedError("SAM models don't support validation")
+        kwargs = {**overrides, **kwargs}  # start from the SAM defaults but prefer caller-supplied kwargs
+        prompts = dict(bboxes=bboxes, points=points, labels=labels)
+        return super().predict(source, stream, prompts=prompts, **kwargs)
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
- def __getattr__(self, attr):
- """Raises error if object has no requested attribute."""
- name = self.__class__.__name__
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
-
def info(self, detailed=False, verbose=True):
"""
Logs model info.
@@ -57,3 +44,7 @@ class SAM:
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
+
+    @property
+    def task_map(self):  # only the 'segment' task is mapped, and only a predictor is provided
+        return {'segment': {'predictor': Predictor}}
diff --git a/ultralytics/models/sam/predict.py b/ultralytics/models/sam/predict.py
index 8f98d58b1..ebd072a7c 100644
--- a/ultralytics/models/sam/predict.py
+++ b/ultralytics/models/sam/predict.py
@@ -28,6 +28,8 @@ class Predictor(BasePredictor):
# Args for set_image
self.im = None
self.features = None
+ # Args for set_prompts
+ self.prompts = {}
# Args for segment everything
self.segment_all = False
@@ -92,6 +94,10 @@ class Predictor(BasePredictor):
of masks and H=W=256. These low resolution logits can be passed to
a subsequent iteration as mask input.
"""
+        # Prompts stored via set_prompts() take precedence over the call arguments and are consumed once
+        bboxes, points = self.prompts.pop('bboxes', bboxes), self.prompts.pop('points', points)
+        labels = self.prompts.pop('labels', labels)  # previously never read back, so point labels were dropped
+        masks = self.prompts.pop('masks', masks)
if all(i is None for i in [bboxes, points, masks]):
return self.generate(im, *args, **kwargs)
return self.prompt_inference(im, bboxes, points, labels, masks, multimask_output)
@@ -348,6 +354,10 @@ class Predictor(BasePredictor):
self.im = im
break
+    def set_prompts(self, prompts):
+        """Set prompts in advance; None is stored as an empty dict because inference pops keys from it."""
+        self.prompts = {} if prompts is None else prompts
+
def reset_image(self):
self.im = None
self.features = None
diff --git a/ultralytics/models/yolo/__init__.py b/ultralytics/models/yolo/__init__.py
index a88c60b88..c66e37627 100644
--- a/ultralytics/models/yolo/__init__.py
+++ b/ultralytics/models/yolo/__init__.py
@@ -2,4 +2,6 @@
from ultralytics.models.yolo import classify, detect, pose, segment
-__all__ = 'classify', 'segment', 'detect', 'pose'
+from .model import YOLO
+
+__all__ = 'classify', 'segment', 'detect', 'pose', 'YOLO'
diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py
index e697a059b..b32156452 100644
--- a/ultralytics/models/yolo/detect/train.py
+++ b/ultralytics/models/yolo/detect/train.py
@@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
+
from copy import copy
import numpy as np
diff --git a/ultralytics/models/yolo/model.py b/ultralytics/models/yolo/model.py
new file mode 100644
index 000000000..b85d46bdb
--- /dev/null
+++ b/ultralytics/models/yolo/model.py
@@ -0,0 +1,36 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license
+
+from ultralytics.engine.model import Model
+from ultralytics.models import yolo # noqa
+from ultralytics.nn.tasks import ClassificationModel, DetectionModel, PoseModel, SegmentationModel
+
+
+class YOLO(Model):
+    """
+    YOLO (You Only Look Once) object detection model.
+    """
+
+    @property
+    def task_map(self):
+        """Return the model, trainer, validator and predictor classes registered for each supported task."""
+        classify = dict(
+            model=ClassificationModel,
+            trainer=yolo.classify.ClassificationTrainer,
+            validator=yolo.classify.ClassificationValidator,
+            predictor=yolo.classify.ClassificationPredictor)
+        detect = dict(
+            model=DetectionModel,
+            trainer=yolo.detect.DetectionTrainer,
+            validator=yolo.detect.DetectionValidator,
+            predictor=yolo.detect.DetectionPredictor)
+        segment = dict(
+            model=SegmentationModel,
+            trainer=yolo.segment.SegmentationTrainer,
+            validator=yolo.segment.SegmentationValidator,
+            predictor=yolo.segment.SegmentationPredictor)
+        pose = dict(
+            model=PoseModel,
+            trainer=yolo.pose.PoseTrainer,
+            validator=yolo.pose.PoseValidator,
+            predictor=yolo.pose.PosePredictor)
+        return {'classify': classify, 'detect': detect, 'segment': segment, 'pose': pose}
diff --git a/ultralytics/models/yolo/segment/train.py b/ultralytics/models/yolo/segment/train.py
index e239c2636..89d5cb223 100644
--- a/ultralytics/models/yolo/segment/train.py
+++ b/ultralytics/models/yolo/segment/train.py
@@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
+
from copy import copy
from ultralytics.models import yolo