|
|
|
@ -1,7 +1,6 @@ |
|
|
|
|
# Ultralytics YOLO 🚀, GPL-3.0 license |
|
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
from copy import deepcopy |
|
|
|
|
from pathlib import Path |
|
|
|
|
from typing import Union |
|
|
|
|
|
|
|
|
@ -78,7 +77,7 @@ class YOLO: |
|
|
|
|
task (Any, optional): Task type for the YOLO model. Defaults to None. |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
self.callbacks = deepcopy(callbacks.default_callbacks) |
|
|
|
|
self._reset_callbacks() |
|
|
|
|
self.predictor = None # reuse predictor |
|
|
|
|
self.model = None # model object |
|
|
|
|
self.trainer = None # trainer object |
|
|
|
@ -118,7 +117,7 @@ class YOLO: |
|
|
|
|
return any(( |
|
|
|
|
model.startswith('https://hub.ultralytics.com/models/'), |
|
|
|
|
[len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID |
|
|
|
|
(len(model) == 20 and not Path(model).exists() and not any(x in model for x in './\\')))) # MODELID |
|
|
|
|
len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID |
|
|
|
|
|
|
|
|
|
def _new(self, cfg: str, task=None, verbose=True): |
|
|
|
|
""" |
|
|
|
@ -228,8 +227,8 @@ class YOLO: |
|
|
|
|
if source is None: |
|
|
|
|
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' |
|
|
|
|
LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") |
|
|
|
|
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and ( |
|
|
|
|
('predict' in sys.argv or 'mode=predict' in sys.argv) or ('track' in sys.argv or 'mode=track' in sys.argv)) |
|
|
|
|
is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any( |
|
|
|
|
x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track')) |
|
|
|
|
overrides = self.overrides.copy() |
|
|
|
|
overrides['conf'] = 0.25 |
|
|
|
|
overrides.update(kwargs) # prefer kwargs |
|
|
|
@ -238,7 +237,7 @@ class YOLO: |
|
|
|
|
overrides['save'] = kwargs.get('save', False) # not save files by default |
|
|
|
|
if not self.predictor: |
|
|
|
|
self.task = overrides.get('task') or self.task |
|
|
|
|
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks) |
|
|
|
|
self.predictor = TASK_MAP[self.task][3](overrides=overrides) |
|
|
|
|
self.predictor.setup_model(model=self.model, verbose=is_cli) |
|
|
|
|
else: # only update args if predictor is already setup |
|
|
|
|
self.predictor.args = get_cfg(self.predictor.args, overrides) |
|
|
|
@ -387,17 +386,19 @@ class YOLO: |
|
|
|
|
""" |
|
|
|
|
return self.model.transforms if hasattr(self.model, 'transforms') else None |
|
|
|
|
|
|
|
|
|
def add_callback(self, event: str, func): |
|
|
|
|
@staticmethod |
|
|
|
|
def add_callback(event: str, func): |
|
|
|
|
""" |
|
|
|
|
Add callback |
|
|
|
|
""" |
|
|
|
|
self.callbacks[event].append(func) |
|
|
|
|
callbacks.default_callbacks[event].append(func) |
|
|
|
|
|
|
|
|
|
@staticmethod |
|
|
|
|
def _reset_ckpt_args(args): |
|
|
|
|
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model |
|
|
|
|
return {k: v for k, v in args.items() if k in include} |
|
|
|
|
|
|
|
|
|
def _reset_callbacks(self): |
|
|
|
|
@staticmethod |
|
|
|
|
def _reset_callbacks(): |
|
|
|
|
for event in callbacks.default_callbacks.keys(): |
|
|
|
|
self.callbacks[event] = [callbacks.default_callbacks[event][0]] |
|
|
|
|
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]] |
|
|
|
|