@@ -8,15 +8,14 @@ from typing import Union
 from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
 from ultralytics.hub.utils import HUB_WEB_ROOT
 from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
-from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, emojis, yaml_load
+from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, emojis, yaml_load
 from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
 from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
 from ultralytics.utils.torch_utils import smart_inference_mode


-class Model:
+class Model(nn.Module):
     """
-    A base model class to unify apis for all the models.
+    A base class to unify APIs for all models.

     Args:
         model (str, Path): Path to the model file to load or create.
@@ -63,6 +62,7 @@ class Model:
             model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
             task (Any, optional): Task type for the YOLO model. Defaults to None.
         """
+        super().__init__()
         self.callbacks = callbacks.get_default_callbacks()
         self.predictor = None  # reuse predictor
         self.model = None  # model object
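Because `Model` now inherits from `nn.Module`, the added `super().__init__()` call must run before any attribute holding a submodule is assigned; `nn.Module.__setattr__` relies on the `_modules`/`_parameters` registries that `nn.Module.__init__` creates. A minimal sketch of that constraint (the `Demo` class and `nn.Linear` layer are illustrative, not part of this diff):

    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()            # must run first: creates the _modules/_parameters registries
            self.model = nn.Linear(4, 2)  # without the call above, this assignment raises
                                          # "AttributeError: cannot assign module before Module.__init__() call"

    demo = Demo()  # constructs cleanly because the registries exist before the submodule is assigned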
@@ -116,13 +116,12 @@ class Model:
         cfg_dict = yaml_model_load(cfg)
         self.cfg = cfg
         self.task = task or guess_model_task(cfg_dict)
-        self.model = (model or self.smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1)  # build model
+        self.model = (model or self._smart_load('model'))(cfg_dict, verbose=verbose and RANK == -1)  # build model
         self.overrides['model'] = self.cfg
         self.overrides['task'] = self.task

         # Below added to allow export from YAMLs
-        args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine model and default args, preferring model args
-        self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
+        self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # combine default and model args (prefer model args)
         self.model.task = self.task

     def _load(self, weights: str, task=None):
@@ -154,12 +153,13 @@ class Model:
         pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
         pt_module = isinstance(self.model, nn.Module)
         if not (pt_module or pt_str):
-            raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
-                            f'PyTorch models can be used to train, val, predict and export, i.e. '
-                            f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
-                            f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
+            raise TypeError(
+                f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
+                f"PyTorch models can train, val, predict and export, i.e. 'model.train(data=...)', but exported "
+                f"formats like ONNX, TensorRT etc. only support 'predict' and 'val' modes, "
+                f"i.e. 'yolo predict model=yolov8n.onnx'.\nTo run CUDA or MPS inference please pass the device "
+                f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'")

-    @smart_inference_mode()
     def reset_weights(self):
         """
         Resets the model modules parameters to randomly initialized values, losing all training information.
@@ -172,7 +172,6 @@ class Model:
             p.requires_grad = True
         return self

-    @smart_inference_mode()
     def load(self, weights='yolov8n.pt'):
         """
         Transfers parameters with matching names and shapes from 'weights' to model.
@@ -199,7 +198,6 @@ class Model:
         self._check_is_pytorch_model()
         self.model.fuse()

-    @smart_inference_mode()
     def predict(self, source=None, stream=False, predictor=None, **kwargs):
         """
         Perform prediction using the YOLO model.
@@ -227,7 +225,7 @@ class Model:
         prompts = args.pop('prompts', None)  # for SAM-type models

         if not self.predictor:
-            self.predictor = (predictor or self.smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
+            self.predictor = (predictor or self._smart_load('predictor'))(overrides=args, _callbacks=self.callbacks)
             self.predictor.setup_model(model=self.model, verbose=is_cli)
         else:  # only update args if predictor is already setup
             self.predictor.args = get_cfg(self.predictor.args, args)
@@ -258,7 +256,6 @@ class Model:
         kwargs['mode'] = 'track'
         return self.predict(source=source, stream=stream, **kwargs)

-    @smart_inference_mode()
     def val(self, validator=None, **kwargs):
         """
         Validate a model on a given dataset.
@@ -271,12 +268,11 @@ class Model:
         args = {**self.overrides, **custom, **kwargs, 'mode': 'val'}  # highest priority args on the right
         args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1)

-        validator = (validator or self.smart_load('validator'))(args=args, _callbacks=self.callbacks)
+        validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks)
         validator(model=self.model)
         self.metrics = validator.metrics
         return validator.metrics

-    @smart_inference_mode()
     def benchmark(self, **kwargs):
         """
         Benchmark a model on all export formats.
@@ -333,7 +329,7 @@ class Model:
         if args.get('resume'):
             args['resume'] = self.ckpt_path

-        self.trainer = (trainer or self.smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
+        self.trainer = (trainer or self._smart_load('trainer'))(overrides=args, _callbacks=self.callbacks)
         if not args.get('resume'):  # manually set model only if not resuming
             self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
             self.model = self.trainer.model
@@ -365,15 +361,12 @@ class Model:
         args = {**self.overrides, **custom, **kwargs, 'mode': 'train'}  # highest priority args on the right
         return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)

-    def to(self, device):
-        """
-        Sends the model to the given device.
-
-        Args:
-            device (str): device
-        """
+    def _apply(self, fn):
+        """Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
         self._check_is_pytorch_model()
-        self.model.to(device)
+        self = super()._apply(fn)  # noqa
+        self.predictor = None  # reset predictor as device may have changed
+        self.overrides['device'] = str(self.device)  # i.e. device(type='cuda', index=0) -> 'cuda:0'
         return self

     @property
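With `to(device)` removed, device and dtype moves now arrive through the methods `Model` inherits from `nn.Module` (`to()`, `cuda()`, `cpu()`, `half()`, `float()`), all of which funnel into the overridden `_apply()` above, so the cached predictor is dropped and `overrides['device']` stays in sync. A rough usage sketch, assuming a CUDA device is available and that the weights file and image source exist (both names are illustrative):

    from ultralytics import YOLO

    model = YOLO('yolov8n.pt')          # Model (and therefore YOLO) is now an nn.Module subclass
    model.to('cuda:0')                  # inherited nn.Module.to() routes through _apply(),
                                        # which clears self.predictor and updates overrides['device']
    results = model.predict('bus.jpg')  # predictor is rebuilt on the new device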
@@ -410,12 +403,12 @@ class Model:
         for event in callbacks.default_callbacks.keys():
             self.callbacks[event] = [callbacks.default_callbacks[event][0]]

-    def __getattr__(self, attr):
-        """Raises error if object has no requested attribute."""
-        name = self.__class__.__name__
-        raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
+    # def __getattr__(self, attr):
+    #     """Raises error if object has no requested attribute."""
+    #     name = self.__class__.__name__
+    #     raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")

-    def smart_load(self, key):
+    def _smart_load(self, key):
         """Load model/trainer/validator/predictor."""
         try:
             return self.task_map[self.task][key]
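`_smart_load()` (renamed from `smart_load()` to mark it as internal) looks up the component class for the current task in `self.task_map`, which each `Model` subclass supplies. A sketch of the shape that mapping takes for the detection task, using classes that ship with the package (the literal dict below is illustrative, not the real `YOLO.task_map`):

    from ultralytics.models.yolo.detect import DetectionPredictor, DetectionTrainer, DetectionValidator
    from ultralytics.nn.tasks import DetectionModel

    task_map = {
        'detect': {
            'model': DetectionModel,          # used by _new() to build a model from a YAML config
            'trainer': DetectionTrainer,      # used by train()
            'validator': DetectionValidator,  # used by val()
            'predictor': DetectionPredictor,  # used by predict()
        }
    }

    # _smart_load('predictor') then reduces to task_map[self.task]['predictor']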