Cli support (#50)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>pull/51/head
parent
4291b9c31c
commit
512a225ce8
9 changed files with 70 additions and 17 deletions
@ -1 +1 @@ |
||||
__version__ = "0.0.1.dev0" |
||||
__version__ = "8.0.0.dev0" |
||||
|
@ -1,7 +1,39 @@ |
||||
import ultralytics.yolo.v8 as v8 |
||||
import hydra |
||||
|
||||
import ultralytics |
||||
import ultralytics.yolo.v8 as yolo |
||||
|
||||
from .engine.model import YOLO |
||||
from .engine.trainer import BaseTrainer |
||||
from .engine.trainer import DEFAULT_CONFIG, BaseTrainer |
||||
from .engine.validator import BaseValidator |
||||
from .utils import LOGGER |
||||
|
||||
__all__ = ["BaseTrainer", "BaseValidator", "YOLO"] # allow simpler import |
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path="utils/configs", config_name="default") |
||||
def cli(cfg): |
||||
LOGGER.info(f"using Ultralytics YOLO v{ultralytics.__version__}") |
||||
module_file = None |
||||
if cfg.task.lower() == "detect": |
||||
module_file = yolo.detect |
||||
elif cfg.task.lower() == "segment": |
||||
module_file = yolo.segment |
||||
elif cfg.task.lower() == "classify": |
||||
module_file = yolo.classify |
||||
|
||||
if not module_file: |
||||
raise Exception("task not recognized. Choices are `'detect', 'segment', 'classify'`") |
||||
|
||||
module_function = None |
||||
|
||||
if cfg.mode.lower() == "train": |
||||
module_function = module_file.train |
||||
elif cfg.mode.lower() == "val": |
||||
module_function = module_file.val |
||||
elif cfg.mode.lower() == "infer": |
||||
module_function = module_file.infer |
||||
|
||||
if not module_function: |
||||
raise Exception("mode not recognized. Choices are `'train', 'val', 'infer'`") |
||||
module_function(cfg) |
||||
|
@ -1,4 +1,4 @@ |
||||
from ultralytics.yolo.v8.classify.train import ClassificationTrainer |
||||
from ultralytics.yolo.v8.classify.train import ClassificationTrainer, train |
||||
from ultralytics.yolo.v8.classify.val import ClassificationValidator |
||||
|
||||
__all__ = ["train"] |
||||
|
@ -1,2 +1,2 @@ |
||||
from ultralytics.yolo.v8.segment.train import SegmentationTrainer |
||||
from ultralytics.yolo.v8.segment.train import SegmentationTrainer, train |
||||
from ultralytics.yolo.v8.segment.val import SegmentationValidator |
||||
|
Loading…
Reference in new issue