|
|
|
@ -3,12 +3,13 @@ |
|
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
from ultralytics import yolo # noqa |
|
|
|
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight |
|
|
|
|
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight, |
|
|
|
|
guess_model_task) |
|
|
|
|
from ultralytics.yolo.cfg import get_cfg |
|
|
|
|
from ultralytics.yolo.engine.exporter import Exporter |
|
|
|
|
from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, yaml_load |
|
|
|
|
from ultralytics.yolo.utils.checks import check_yaml |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, smart_inference_mode |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import smart_inference_mode |
|
|
|
|
|
|
|
|
|
# Map head to model, trainer, validator, and predictor classes |
|
|
|
|
MODEL_MAP = { |
|
|
|
@ -73,9 +74,9 @@ class YOLO: |
|
|
|
|
""" |
|
|
|
|
cfg = check_yaml(cfg) # check YAML |
|
|
|
|
cfg_dict = yaml_load(cfg, append_filename=True) # model dict |
|
|
|
|
self.task = guess_task_from_model_yaml(cfg_dict) |
|
|
|
|
self.task = guess_model_task(cfg_dict) |
|
|
|
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \ |
|
|
|
|
self._guess_ops_from_task(self.task) |
|
|
|
|
self._assign_ops_from_task(self.task) |
|
|
|
|
self.model = self.ModelClass(cfg_dict, verbose=verbose) # initialize |
|
|
|
|
self.cfg = cfg |
|
|
|
|
|
|
|
|
@ -92,7 +93,7 @@ class YOLO: |
|
|
|
|
self.overrides = self.model.args |
|
|
|
|
self._reset_ckpt_args(self.overrides) |
|
|
|
|
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = \ |
|
|
|
|
self._guess_ops_from_task(self.task) |
|
|
|
|
self._assign_ops_from_task(self.task) |
|
|
|
|
|
|
|
|
|
def reset(self): |
|
|
|
|
""" |
|
|
|
@ -217,7 +218,7 @@ class YOLO: |
|
|
|
|
""" |
|
|
|
|
self.model.to(device) |
|
|
|
|
|
|
|
|
|
def _guess_ops_from_task(self, task): |
|
|
|
|
def _assign_ops_from_task(self, task): |
|
|
|
|
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[task] |
|
|
|
|
# warning: eval is unsafe. Use with caution |
|
|
|
|
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}")) |
|
|
|
|