Add initial model interface (#30)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/31/head
parent
7b560f7861
commit
1054819a59
12 changed files with 220 additions and 109 deletions
@ -0,0 +1,13 @@ |
|||||||
|
from ultralytics.yolo import YOLO |
||||||
|
|
||||||
|
|
||||||
|
def test_model(): |
||||||
|
model = YOLO() |
||||||
|
model.new("assets/dummy_model.yaml") |
||||||
|
model.model = "squeezenet1_0" # temp solution before get_model is implemented |
||||||
|
# model.load("yolov5n.pt") |
||||||
|
model.train(data="imagenette160", epochs=1, lr0=0.01) |
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__": |
||||||
|
test_model() |
@ -1,4 +1,7 @@ |
|||||||
|
import ultralytics.yolo.v8 as v8 |
||||||
|
|
||||||
|
from .engine.model import YOLO |
||||||
from .engine.trainer import BaseTrainer |
from .engine.trainer import BaseTrainer |
||||||
from .engine.validator import BaseValidator |
from .engine.validator import BaseValidator |
||||||
|
|
||||||
__all__ = ["BaseTrainer", "BaseValidator"] # allow simpler import |
__all__ = ["BaseTrainer", "BaseValidator", "YOLO"] # allow simpler import |
||||||
|
@ -0,0 +1,63 @@ |
|||||||
|
""" |
||||||
|
Top-level YOLO model interface. First principle usage example - https://github.com/ultralytics/ultralytics/issues/13 |
||||||
|
""" |
||||||
|
import torch |
||||||
|
import yaml |
||||||
|
|
||||||
|
import ultralytics.yolo as yolo |
||||||
|
from ultralytics.yolo.utils import LOGGER |
||||||
|
from ultralytics.yolo.utils.checks import check_yaml |
||||||
|
from ultralytics.yolo.utils.modeling.tasks import ClassificationModel, DetectionModel, SegmentationModel |
||||||
|
|
||||||
|
# map head: [model, trainer] |
||||||
|
MODEL_MAP = { |
||||||
|
"Classify": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'], |
||||||
|
"Detect": [ClassificationModel, 'yolo.VERSION.classify.train.ClassificationTrainer'], # temp |
||||||
|
"Segment": []} |
||||||
|
|
||||||
|
|
||||||
|
class YOLO: |
||||||
|
|
||||||
|
def __init__(self, version=8) -> None: |
||||||
|
self.version = version |
||||||
|
self.model = None |
||||||
|
self.trainer = None |
||||||
|
self.pretrained_weights = None |
||||||
|
|
||||||
|
def new(self, cfg: str): |
||||||
|
cfg = check_yaml(cfg) # check YAML |
||||||
|
self.model, self.trainer = self._get_model_and_trainer(cfg) |
||||||
|
|
||||||
|
def load(self, weights, autodownload=True): |
||||||
|
if not isinstance(self.pretrained_weights, type(None)): |
||||||
|
LOGGER.info("Overwriting weights") |
||||||
|
# TODO: weights = smart_file_loader(weights) |
||||||
|
if self.model: |
||||||
|
self.model.load(weights) |
||||||
|
LOGGER.info("Checkpoint loaded successfully") |
||||||
|
else: |
||||||
|
# TODO: infer model and trainer |
||||||
|
pass |
||||||
|
|
||||||
|
self.pretrained_weights = weights |
||||||
|
|
||||||
|
def reset(self): |
||||||
|
pass |
||||||
|
|
||||||
|
def train(self, **kwargs): |
||||||
|
if 'data' not in kwargs: |
||||||
|
raise Exception("data is required to train") |
||||||
|
if not self.model: |
||||||
|
raise Exception("model not initialized. Use .new() or .load()") |
||||||
|
kwargs["model"] = self.model |
||||||
|
trainer = self.trainer(overrides=kwargs) |
||||||
|
trainer.train() |
||||||
|
|
||||||
|
def _get_model_and_trainer(self, cfg): |
||||||
|
with open(cfg, encoding='ascii', errors='ignore') as f: |
||||||
|
cfg = yaml.safe_load(f) # model dict |
||||||
|
model, trainer = MODEL_MAP[cfg["head"][-1][-2]] |
||||||
|
# warning: eval is unsafe. Use with caution |
||||||
|
trainer = eval(trainer.replace("VERSION", f"v{self.version}")) |
||||||
|
|
||||||
|
return model(cfg), trainer |
@ -1,3 +1,4 @@ |
|||||||
from ultralytics.yolo.v8.classify import train |
from ultralytics.yolo.v8.classify.train import ClassificationTrainer |
||||||
|
from ultralytics.yolo.v8.classify.val import ClassificationValidator |
||||||
|
|
||||||
__all__ = ["train"] |
__all__ = ["train"] |
||||||
|
Loading…
Reference in new issue