Add initial model interface (#30)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/31/head
parent
7b560f7861
commit
1054819a59
12 changed files with 220 additions and 109 deletions
@ -0,0 +1,13 @@ |
||||
from ultralytics.yolo import YOLO |
||||
|
||||
|
||||
def test_model(): |
||||
model = YOLO() |
||||
model.new("assets/dummy_model.yaml") |
||||
model.model = "squeezenet1_0" # temp solution before get_model is implemented |
||||
# model.load("yolov5n.pt") |
||||
model.train(data="imagenette160", epochs=1, lr0=0.01) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
test_model() |
@ -1,4 +1,7 @@ |
||||
import ultralytics.yolo.v8 as v8 |
||||
|
||||
from .engine.model import YOLO |
||||
from .engine.trainer import BaseTrainer |
||||
from .engine.validator import BaseValidator |
||||
|
||||
__all__ = ["BaseTrainer", "BaseValidator"] # allow simpler import |
||||
__all__ = ["BaseTrainer", "BaseValidator", "YOLO"] # allow simpler import |
||||
|
@ -0,0 +1,63 @@ |
||||
""" |
||||
Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13 |
||||
""" |
||||
import torch |
||||
import yaml |
||||
|
||||
import ultralytics.yolo as yolo |
||||
from ultralytics.yolo.utils import LOGGER |
||||
from ultralytics.yolo.utils.checks import check_yaml |
||||
from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel |
||||
|
||||
# map head: [model, trainer] |
||||
MODEL_MAP = { |
||||
"Classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'], |
||||
"Detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'], # temp |
||||
"Segment": []} |
||||
|
||||
|
||||
class YOLO: |
||||
|
||||
def __init__(self, version=8) -> None: |
||||
self.version = version |
||||
self.model = None |
||||
self.trainer = None |
||||
self.pretrained_weights = None |
||||
|
||||
def new(self, cfg: str): |
||||
cfg = check_yaml(cfg) # check YAML |
||||
self.model, self.trainer = self._get_model_and_trainer(cfg) |
||||
|
||||
def load(self, weights, autodownload=True): |
||||
if not isinstance(self.pretrained_weights, type(None)): |
||||
LOGGER.info("Overwriting weights") |
||||
# TODO: weights = smart_file_loader(weights) |
||||
if self.model: |
||||
self.model.load(weights) |
||||
LOGGER.info("Checkpoint loaded successfully") |
||||
else: |
||||
# TODO: infer model and trainer |
||||
pass |
||||
|
||||
self.pretrained_weights = weights |
||||
|
||||
def reset(self): |
||||
pass |
||||
|
||||
def train(self, **kwargs): |
||||
if 'data' not in kwargs: |
||||
raise Exception("data is required to train") |
||||
if not self.model: |
||||
raise Exception("model not initialized. Use .new() or .load()") |
||||
kwargs["model"] = self.model |
||||
trainer = self.trainer(overrides=kwargs) |
||||
trainer.train() |
||||
|
||||
def _get_model_and_trainer(self, cfg): |
||||
with open(cfg, encoding='ascii', errors='ignore') as f: |
||||
cfg = yaml.safe_load(f) # model dict |
||||
model, trainer = MODEL_MAP[cfg["head"][-1][-2]] |
||||
# warning: eval is unsafe. Use with caution |
||||
trainer = eval(trainer.replace("VERSION", f"v{self.version}")) |
||||
|
||||
return model(cfg), trainer |
@ -1,3 +1,4 @@ |
||||
from ultralytics.yolo.v8.classify import train |
||||
from ultralytics.yolo.v8.classify.train import ClassificationTrainer |
||||
from ultralytics.yolo.v8.classify.val import ClassificationValidator |
||||
|
||||
__all__ = ["train"] |
||||
|
Loading…
Reference in new issue