diff --git a/docs/en/reference/cfg/__init__.md b/docs/en/reference/cfg/__init__.md
index b8db12c5d..c6627fd12 100644
--- a/docs/en/reference/cfg/__init__.md
+++ b/docs/en/reference/cfg/__init__.md
@@ -19,6 +19,10 @@ keywords: Ultralytics, YOLO, Configuration, cfg2dict, handle_deprecation, merge_
 
 <br><br>
 
+## ::: ultralytics.cfg.check_cfg
+
+<br><br>
+
 ## ::: ultralytics.cfg.get_save_dir
 
 <br><br>
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 1a38daf9e..8e442f850 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.1.30"
+__version__ = "8.1.31"
 
 from ultralytics.data.explorer.explorer import Explorer
 from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py
index d9d737581..de8347cf8 100644
--- a/ultralytics/cfg/__init__.py
+++ b/ultralytics/cfg/__init__.py
@@ -30,8 +30,8 @@ from ultralytics.utils import (
 )
 
 # Define valid tasks and modes
-MODES = "train", "val", "predict", "export", "track", "benchmark"
-TASKS = "detect", "segment", "classify", "pose", "obb"
+MODES = {"train", "val", "predict", "export", "track", "benchmark"}
+TASKS = {"detect", "segment", "classify", "pose", "obb"}
 TASK2DATA = {
     "detect": "coco8.yaml",
     "segment": "coco8-seg.yaml",
@@ -93,8 +93,8 @@ CLI_HELP_MSG = f"""
     """
 
 # Define keys for arg type checks
-CFG_FLOAT_KEYS = "warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"
-CFG_FRACTION_KEYS = (
+CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time"}
+CFG_FRACTION_KEYS = {
     "dropout",
     "iou",
     "lr0",
@@ -118,8 +118,8 @@ CFG_FRACTION_KEYS = (
     "conf",
     "iou",
     "fraction",
-)  # fraction floats 0.0 - 1.0
-CFG_INT_KEYS = (
+}  # fraction floats 0.0 - 1.0
+CFG_INT_KEYS = {
     "epochs",
     "patience",
     "batch",
@@ -133,8 +133,8 @@ CFG_INT_KEYS = (
     "workspace",
     "nbs",
     "save_period",
-)
-CFG_BOOL_KEYS = (
+}
+CFG_BOOL_KEYS = {
     "save",
     "exist_ok",
     "verbose",
@@ -169,7 +169,7 @@ CFG_BOOL_KEYS = (
     "nms",
     "profile",
     "multi_scale",
-)
+}
 
 
 def cfg2dict(cfg):
@@ -219,33 +219,46 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
             LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")
 
     # Type and Value checks
+    check_cfg(cfg)
+
+    # Return instance
+    return IterableSimpleNamespace(**cfg)
+
+
+def check_cfg(cfg, hard=True):
+    """Check Ultralytics configuration argument types and values."""
     for k, v in cfg.items():
         if v is not None:  # None values may be from optional args
             if k in CFG_FLOAT_KEYS and not isinstance(v, (int, float)):
-                raise TypeError(
-                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
-                    f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
-                )
-            elif k in CFG_FRACTION_KEYS:
-                if not isinstance(v, (int, float)):
+                if hard:
                     raise TypeError(
                         f"'{k}={v}' is of invalid type {type(v).__name__}. "
                         f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
                     )
+                cfg[k] = float(v)
+            elif k in CFG_FRACTION_KEYS:
+                if not isinstance(v, (int, float)):
+                    if hard:
+                        raise TypeError(
+                            f"'{k}={v}' is of invalid type {type(v).__name__}. "
+                            f"Valid '{k}' types are int (i.e. '{k}=0') or float (i.e. '{k}=0.5')"
+                        )
+                    cfg[k] = float(v)
                 if not (0.0 <= v <= 1.0):
                     raise ValueError(f"'{k}={v}' is an invalid value. " f"Valid '{k}' values are between 0.0 and 1.0.")
             elif k in CFG_INT_KEYS and not isinstance(v, int):
-                raise TypeError(
-                    f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
-                )
+                if hard:
+                    raise TypeError(
+                        f"'{k}={v}' is of invalid type {type(v).__name__}. " f"'{k}' must be an int (i.e. '{k}=8')"
+                    )
+                cfg[k] = int(v)
             elif k in CFG_BOOL_KEYS and not isinstance(v, bool):
-                raise TypeError(
-                    f"'{k}={v}' is of invalid type {type(v).__name__}. "
-                    f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
-                )
-
-    # Return instance
-    return IterableSimpleNamespace(**cfg)
+                if hard:
+                    raise TypeError(
+                        f"'{k}={v}' is of invalid type {type(v).__name__}. "
+                        f"'{k}' must be a bool (i.e. '{k}=True' or '{k}=False')"
+                    )
+                cfg[k] = bool(v)
 
 
 def get_save_dir(args, name=None):
@@ -464,10 +477,10 @@ def entrypoint(debug=""):
     overrides = {}  # basic overrides, i.e. imgsz=320
     for a in merge_equals_args(args):  # merge spaces around '=' sign
         if a.startswith("--"):
-            LOGGER.warning(f"WARNING ⚠️ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
+            LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
             a = a[2:]
         if a.endswith(","):
-            LOGGER.warning(f"WARNING ⚠️ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
+            LOGGER.warning(f"WARNING ⚠️ argument '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
             a = a[:-1]
         if "=" in a:
             try:
@@ -504,7 +517,7 @@ def entrypoint(debug=""):
     mode = overrides.get("mode")
     if mode is None:
         mode = DEFAULT_CFG.mode or "predict"
-        LOGGER.warning(f"WARNING ⚠️ 'mode' is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
+        LOGGER.warning(f"WARNING ⚠️ 'mode' argument is missing. Valid modes are {MODES}. Using default 'mode={mode}'.")
     elif mode not in MODES:
         raise ValueError(f"Invalid 'mode={mode}'. Valid modes are {MODES}.\n{CLI_HELP_MSG}")
 
@@ -520,7 +533,7 @@ def entrypoint(debug=""):
     model = overrides.pop("model", DEFAULT_CFG.model)
     if model is None:
         model = "yolov8n.pt"
-        LOGGER.warning(f"WARNING ⚠️ 'model' is missing. Using default 'model={model}'.")
+        LOGGER.warning(f"WARNING ⚠️ 'model' argument is missing. Using default 'model={model}'.")
     overrides["model"] = model
     stem = Path(model).stem.lower()
     if "rtdetr" in stem:  # guess architecture
@@ -554,15 +567,15 @@ def entrypoint(debug=""):
     # Mode
     if mode in ("predict", "track") and "source" not in overrides:
         overrides["source"] = DEFAULT_CFG.source or ASSETS
-        LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.")
+        LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
     elif mode in ("train", "val"):
         if "data" not in overrides and "resume" not in overrides:
             overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
-            LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.")
+            LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
     elif mode == "export":
         if "format" not in overrides:
             overrides["format"] = DEFAULT_CFG.format or "torchscript"
-            LOGGER.warning(f"WARNING ⚠️ 'format' is missing. Using default 'format={overrides['format']}'.")
+            LOGGER.warning(f"WARNING ⚠️ 'format' argument is missing. Using default 'format={overrides['format']}'.")
 
     # Run command in python
     getattr(model, mode)(**overrides)  # default args from model
diff --git a/ultralytics/data/build.py b/ultralytics/data/build.py
index 37c5fa417..6bfb48f33 100644
--- a/ultralytics/data/build.py
+++ b/ultralytics/data/build.py
@@ -129,7 +129,7 @@ def check_source(source):
     webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
     if isinstance(source, (str, int, Path)):  # int for local usb camera
         source = str(source)
-        is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
+        is_file = Path(source).suffix[1:] in (IMG_FORMATS | VID_FORMATS)
         is_url = source.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
         webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
         screenshot = source.lower() == "screen"
diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py
index 1ce8127aa..c0a077368 100644
--- a/ultralytics/data/utils.py
+++ b/ultralytics/data/utils.py
@@ -35,8 +35,8 @@ from ultralytics.utils.downloads import download, safe_download, unzip_file
 from ultralytics.utils.ops import segments2boxes
 
 HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatting guidance."
-IMG_FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"  # image suffixes
-VID_FORMATS = "asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"  # video suffixes
+IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"}  # image suffixes
+VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
 PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true"  # global pin_memory for dataloaders
 
 
diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py
index d8d944f99..babe5d3b1 100644
--- a/ultralytics/engine/results.py
+++ b/ultralytics/engine/results.py
@@ -385,14 +385,12 @@ class Results(SimpleClass):
             BGR=True,
         )
 
-    def tojson(self, normalize=False):
-        """Convert the object to JSON format."""
+    def summary(self, normalize=False):
+        """Convert the results to a summarized format."""
         if self.probs is not None:
-            LOGGER.warning("Warning: Classify task do not support `tojson` yet.")
+            LOGGER.warning("Warning: Classify task do not support `summary` and `tojson` yet.")
             return
 
-        import json
-
         # Create list of detection dictionaries
         results = []
         data = self.boxes.data.cpu().tolist()
@@ -413,8 +411,13 @@ class Results(SimpleClass):
                 result["keypoints"] = {"x": (x / w).tolist(), "y": (y / h).tolist(), "visible": visible.tolist()}
             results.append(result)
 
-        # Convert detections to JSON
-        return json.dumps(results, indent=2)
+        return results
+
+    def tojson(self, normalize=False):
+        """Convert the results to JSON format."""
+        import json
+
+        return json.dumps(self.summary(normalize=normalize), indent=2)
 
 
 class Boxes(BaseTensor):
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index bebfdbe02..19514c658 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -509,14 +509,9 @@ class AutoBackend(nn.Module):
         # NCNN
         elif self.ncnn:
             mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
-            ex = self.net.create_extractor()
-            input_names, output_names = self.net.input_names(), self.net.output_names()
-            ex.input(input_names[0], mat_in)
-            y = []
-            for output_name in output_names:
-                mat_out = self.pyncnn.Mat()
-                ex.extract(output_name, mat_out)
-                y.append(np.array(mat_out)[None])
+            with self.net.create_extractor() as ex:
+                ex.input(self.net.input_names()[0], mat_in)
+                y = [np.array(ex.extract(x)[1])[None] for x in self.net.output_names()]
 
         # NVIDIA Triton Inference Server
         elif self.triton:
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 64ee7f503..bb1732fd8 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -560,7 +560,8 @@ class WorldModel(DetectionModel):
 
     def __init__(self, cfg="yolov8s-world.yaml", ch=3, nc=None, verbose=True):
         """Initialize YOLOv8 world model with given config and parameters."""
-        self.txt_feats = torch.randn(1, nc or 80, 512)  # placeholder
+        self.txt_feats = torch.randn(1, nc or 80, 512)  # features placeholder
+        self.clip_model = None  # CLIP model placeholder
         super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
 
     def set_classes(self, text):
@@ -571,10 +572,11 @@ class WorldModel(DetectionModel):
             check_requirements("git+https://github.com/openai/CLIP.git")
             import clip
 
-        model, _ = clip.load("ViT-B/32")
-        device = next(model.parameters()).device
+        if not self.clip_model:
+            self.clip_model = clip.load("ViT-B/32")[0]
+        device = next(self.clip_model.parameters()).device
         text_token = clip.tokenize(text).to(device)
-        txt_feats = model.encode_text(text_token).to(dtype=torch.float32)
+        txt_feats = self.clip_model.encode_text(text_token).to(dtype=torch.float32)
         txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
         self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1]).detach()
         self.model[-1].nc = len(text)
@@ -841,7 +843,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                     args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
 
         n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
-        if m in (
+        if m in {
             Classify,
             Conv,
             ConvTranspose,
@@ -867,7 +869,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             DWConvTranspose2d,
             C3x,
             RepC3,
-        ):
+        }:
             c1, c2 = ch[f], args[0]
             if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(min(c2, max_channels) * width, 8)
@@ -883,7 +885,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             n = 1
         elif m is AIFI:
             args = [ch[f], *args]
-        elif m in (HGStem, HGBlock):
+        elif m in {HGStem, HGBlock}:
             c1, cm, c2 = ch[f], args[0], args[1]
             args = [c1, cm, c2, *args[2:]]
             if m is HGBlock:
@@ -895,7 +897,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in (Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn):
+        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
             args.append([ch[x] for x in f])
             if m is Segment:
                 args[2] = make_divisible(min(args[2], max_channels) * width, 8)
@@ -978,7 +980,7 @@ def guess_model_task(model):
     def cfg2task(cfg):
         """Guess from YAML dictionary."""
         m = cfg["head"][-1][-2].lower()  # output module name
-        if m in ("classify", "classifier", "cls", "fc"):
+        if m in {"classify", "classifier", "cls", "fc"}:
             return "classify"
         if m == "detect":
             return "detect"
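
Note: the snippet below is not part of the patch above. It is a minimal usage sketch of the two user-facing additions in this release, the new check_cfg() helper (with its soft hard=False mode) and Results.summary() with tojson() delegating to it. It assumes ultralytics>=8.1.31 is installed; the yolov8n.pt weights and the sample image URL are illustrative and are fetched on first use.

# sketch.py -- exercises check_cfg() and Results.summary() from the 8.1.31 changes above
from ultralytics import YOLO
from ultralytics.cfg import check_cfg

# check_cfg() with hard=False coerces invalid argument types in place instead of raising.
cfg = {"epochs": "3", "save": 0}  # wrong types on purpose: str for an int key, int for a bool key
check_cfg(cfg, hard=False)
print(cfg)  # {'epochs': 3, 'save': False}
# With the default hard=True the same input raises TypeError, matching the old get_cfg() behavior.

# Results.summary() returns a list of per-detection dicts; tojson() now just serializes that list.
model = YOLO("yolov8n.pt")
results = model("https://ultralytics.com/images/bus.jpg")
print(results[0].summary())               # list of dicts with name, class, confidence, box, ...
print(results[0].tojson(normalize=True))  # same data as a JSON string, box coordinates normalized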