"""
Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13
"""
import operator

import yaml

import ultralytics.yolo as yolo
from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils.checks import check_yaml
from ultralytics.yolo.utils.modeling import get_model
from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel

# map head: [model, trainer]
# The trainer is a dotted path; "VERSION" is replaced by e.g. "v8" when it is resolved.
MODEL_MAP = {
    "classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],
    "detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'],  # temp
    "segment": []}  # no model/trainer registered yet

# Lower-cased head names recognised by YOLO._guess_model_and_trainer, mapped to MODEL_MAP tasks.
_HEAD_TO_TASK = {
    "classify": "classify",
    "classifier": "classify",
    "cls": "classify",
    "fc": "classify",
    "detect": "detect",
    "segment": "segment"}


def _get_task_entry(task):
    """Return the ``[model_class, trainer_path]`` MODEL_MAP entry for ``task``.

    Raises:
        ValueError: if ``task`` is not a key of MODEL_MAP.
        NotImplementedError: if the task is listed but has no model/trainer registered (empty entry).
    """
    if task not in MODEL_MAP:
        raise ValueError(f"Unsupported task {task}. The supported tasks are: \n {MODEL_MAP.keys()}")
    entry = MODEL_MAP[task]
    if not entry:
        raise NotImplementedError(f"Task {task} is not implemented yet")
    return entry


def _resolve_trainer(trainer_path, version):
    """Resolve a dotted trainer path such as 'yolo.VERSION.classify.train.ClassificationTrainer'.

    "VERSION" is substituted with f"v{version}" and the remaining attribute chain is looked up on the
    imported ``yolo`` package. This performs the same attribute lookup the previous ``eval`` did,
    without evaluating arbitrary code.

    Raises:
        ValueError: if the path does not start with ``yolo.``.
        AttributeError: if any component of the path does not exist.
    """
    dotted = trainer_path.replace("VERSION", f"v{version}")
    root, _, attrs = dotted.partition(".")
    if root != "yolo" or not attrs:
        raise ValueError(f"Trainer path must start with 'yolo.', got {trainer_path!r}")
    return operator.attrgetter(attrs)(yolo)


class YOLO:
    """High-level wrapper that builds or loads a model and hands it to the matching trainer.

    Args:
        task: optional task name (a MODEL_MAP key, case-insensitive). When given, the model and
            trainer classes are fixed up front; otherwise they are inferred in ``new``/``load``.
        version: version number used to pick the trainer package (8 -> ``yolo.v8``).
    """

    def __init__(self, task=None, version=8) -> None:
        self.version = version
        self.ModelClass = None  # model class taken from MODEL_MAP
        self.TrainerClass = None  # resolved trainer class (not its dotted path)
        self.model = None
        self.pretrained_weights = None
        if task:
            model_class, trainer_path = _get_task_entry(task.lower())
            self.ModelClass = model_class
            self.TrainerClass = _resolve_trainer(trainer_path, self.version)

    def new(self, cfg: str):
        """Build a fresh model from a YAML config.

        Args:
            cfg: path to a model YAML file; validated with ``check_yaml``.

        If the constructor fixed a task, its model class is instantiated with the YAML path.
        Otherwise the YAML is parsed and the model/trainer classes are guessed from its ``head``.
        """
        cfg = check_yaml(cfg)  # check YAML
        if self.ModelClass is not None:
            # Model class already chosen in __init__; self.model is not built yet.
            self.model = self.ModelClass(cfg)
        else:
            with open(cfg, encoding='ascii', errors='ignore') as f:
                cfg = yaml.safe_load(f)  # model dict
            self.ModelClass, self.TrainerClass = self._guess_model_and_trainer(cfg["head"])
            self.model = self.ModelClass(cfg)  # initialize

    def load(self, weights, autodownload=True):
        """Load weights into the existing model, or create the model from the weights.

        Args:
            weights: passed to ``self.model.load`` when a model exists, else to ``get_model``.
            autodownload: currently unused; kept for API compatibility.
        """
        if self.pretrained_weights is not None:
            LOGGER.info("Overwriting weights")
        # TODO: weights = smart_file_loader(weights)
        if self.model is not None:
            self.model.load(weights)
            LOGGER.info("Checkpoint loaded successfully")
        else:
            self.model = get_model(weights)
            self.ModelClass, self.TrainerClass = self._guess_model_and_trainer(list(self.model.named_children()))
        self.pretrained_weights = weights

    def reset(self):
        """Call ``reset_parameters`` on every submodule that has it and make all parameters trainable.

        Raises:
            RuntimeError: if no model has been created or loaded yet.
        """
        if self.model is None:
            raise RuntimeError("model not initialized. Use .new() or .load()")
        for m in self.model.modules():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()
        for p in self.model.parameters():
            p.requires_grad = True

    def train(self, **kwargs):
        """Train the current model with the resolved trainer class.

        Args:
            **kwargs: trainer overrides; must include ``data``.

        Raises:
            ValueError: if ``data`` is missing.
            RuntimeError: if no model has been created or loaded yet.
        """
        if 'data' not in kwargs:
            raise ValueError("data is required to train")
        if self.model is None:
            raise RuntimeError("model not initialized. Use .new() or .load()")
        trainer = self.TrainerClass(overrides=kwargs)
        trainer.model = self.model
        trainer.train()

    def _guess_model_and_trainer(self, cfg):
        """Infer ``(model_class, trainer_class)`` from a model's layer description.

        Args:
            cfg: sequence whose last item's second-to-last element names the head. Callers pass
                the YAML ``head`` list or ``list(model.named_children())`` (``(name, module)`` pairs).
                The exact YAML row layout is assumed, not checked here.

        Raises:
            ValueError: if the head name is not recognised.
            NotImplementedError: if the inferred task has no registered model/trainer.
        """
        # TODO: warn
        head = str(cfg[-1][-2]).lower()
        task = _HEAD_TO_TASK.get(head)
        if task is None:
            raise ValueError(f"Could not infer task from model head '{head}'")
        model_class, trainer_path = _get_task_entry(task)
        return model_class, _resolve_trainer(trainer_path, self.version)


if __name__ == "__main__":
    model = YOLO()
    # model.new("assets/dummy_model.yaml")
    model.load("yolov5n-cls.pt")
    model.train(data="imagenette160", epochs=1, lr0=0.01)