|
|
|
@ -47,6 +47,7 @@ class YOLO: |
|
|
|
|
self.trainer = None # trainer object |
|
|
|
|
self.task = None # task type |
|
|
|
|
self.ckpt = None # if loaded from *.pt |
|
|
|
|
self.ckpt_path = None |
|
|
|
|
self.cfg = None # if loaded from *.yaml |
|
|
|
|
self.overrides = {} # overrides for trainer object |
|
|
|
|
self.init_disabled = False # disable model initialization |
|
|
|
@ -78,6 +79,7 @@ class YOLO: |
|
|
|
|
weights (str): model checkpoint to be loaded |
|
|
|
|
""" |
|
|
|
|
self.model = attempt_load_weights(weights) |
|
|
|
|
self.ckpt_path = weights |
|
|
|
|
self.task = self.model.args["task"] |
|
|
|
|
self.overrides = self.model.args |
|
|
|
|
self.overrides["device"] = '' # reset device |
|
|
|
@ -177,8 +179,8 @@ class YOLO: |
|
|
|
|
""" |
|
|
|
|
if not self.model: |
|
|
|
|
raise AttributeError("model not initialized. Use .new() or .load()") |
|
|
|
|
|
|
|
|
|
overrides = kwargs |
|
|
|
|
overrides = self.overrides.copy() |
|
|
|
|
overrides.update(kwargs) |
|
|
|
|
if kwargs.get("cfg"): |
|
|
|
|
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.") |
|
|
|
|
overrides = yaml_load(check_yaml(kwargs["cfg"])) |
|
|
|
@ -187,10 +189,13 @@ class YOLO: |
|
|
|
|
if not overrides.get("data"): |
|
|
|
|
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.") |
|
|
|
|
|
|
|
|
|
if overrides.get("resume"): |
|
|
|
|
overrides["resume"] = self.ckpt_path |
|
|
|
|
self.trainer = self.TrainerClass(overrides=overrides) |
|
|
|
|
self.trainer.model = self.trainer.load_model(weights=self.model, |
|
|
|
|
model_cfg=self.model.yaml if self.task != "classify" else None) |
|
|
|
|
self.model = self.trainer.model # override here to save memory |
|
|
|
|
if not overrides.get("resume"): |
|
|
|
|
self.trainer.model = self.trainer.load_model(weights=self.model, |
|
|
|
|
model_cfg=self.model.yaml if self.task != "classify" else None) |
|
|
|
|
self.model = self.trainer.model # override here to save memory |
|
|
|
|
|
|
|
|
|
self.trainer.train() |
|
|
|
|
|
|
|
|
|