diff --git a/docs/cfg.md b/docs/cfg.md index 5197f0b49..4ec5a5b63 100644 --- a/docs/cfg.md +++ b/docs/cfg.md @@ -55,7 +55,7 @@ include train, val, and predict. | mode | 'train' | YOLO mode, i.e. train, val, predict, or export | | resume | False | resume training from last checkpoint or custom checkpoint if passed as resume=path/to/best.pt | | model | null | path to model file, i.e. yolov8n.pt, yolov8n.yaml | -| data | null | path to data file, i.e. i.e. coco128.yaml | +| data | null | path to data file, i.e. coco128.yaml | ### Training @@ -69,7 +69,7 @@ task. | Key | Value | Description | |-----------------|--------|--------------------------------------------------------------------------------| | model | null | path to model file, i.e. yolov8n.pt, yolov8n.yaml | -| data | null | path to data file, i.e. i.e. coco128.yaml | +| data | null | path to data file, i.e. coco128.yaml | | epochs | 100 | number of epochs to train for | | patience | 50 | epochs to wait for no observable improvement for early stopping of training | | batch | 16 | number of images per batch (-1 for AutoBatch) | diff --git a/docs/predict.md b/docs/predict.md index e2a9d2379..a52b0d8ab 100644 --- a/docs/predict.md +++ b/docs/predict.md @@ -47,7 +47,7 @@ source can be used as a stream and the model argument required for that source. | CSV | | `'sources.csv'` | `str`, `Path` | RTSP, RTMP, HTTP | | video | ✓ | `'vid.mp4'` | `str`, `Path` | | | directory | ✓ | `'path/'` | `str`, `Path` | | -| glob | ✓ | `path/*.jpg'` | `str` | Use `*` operator | +| glob | ✓ | `'path/*.jpg'` | `str` | Use `*` operator | | YouTube | ✓ | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | | | stream | ✓ | `'rtsp://example.com/media.mp4'` | `str` | RTSP, RTMP, HTTP | diff --git a/tests/test_python.py b/tests/test_python.py index d576b8e6d..243fbe571 100644 --- a/tests/test_python.py +++ b/tests/test_python.py @@ -49,6 +49,8 @@ def test_predict_dir(): def test_predict_img(): model = YOLO(MODEL) + seg_model = YOLO('yolov8n-seg.pt') + cls_model = YOLO('yolov8n-cls.pt') im = cv2.imread(str(SOURCE)) assert len(model(source=Image.open(SOURCE), save=True, verbose=True)) == 1 # PIL assert len(model(source=im, save=True, save_txt=True)) == 1 # ndarray @@ -64,6 +66,18 @@ def test_predict_img(): np.zeros((320, 640, 3))] # numpy assert len(model(batch)) == len(batch) # multiple sources in a batch + # Test tensor inference + im = cv2.imread(str(SOURCE)) # OpenCV + t = cv2.resize(im, (32, 32)) + t = torch.from_numpy(t.transpose((2, 0, 1))) + t = torch.stack([t, t, t, t]) + results = model(t) + assert len(results) == t.shape[0] + results = seg_model(t) + assert len(results) == t.shape[0] + results = cls_model(t) + assert len(results) == t.shape[0] + def test_predict_grey_and_4ch(): model = YOLO(MODEL) @@ -199,3 +213,6 @@ def test_result(): res = model(SOURCE) res[0].plot() print(res[0].path) + + +test_predict_img() diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index da14ed06c..403b610e8 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, GPL-3.0 license -__version__ = '8.0.49' +__version__ = '8.0.50' from ultralytics.yolo.engine.model import YOLO from ultralytics.yolo.utils.checks import check_yolo as checks diff --git a/ultralytics/models/v5/yolov5l6u.yaml b/ultralytics/models/v5/yolov5l6u.yaml new file mode 100644 index 000000000..76da02ac9 --- /dev/null +++ b/ultralytics/models/v5/yolov5l6u.yaml @@ -0,0 +1,55 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 + [-1, 3, C3, [768]], + [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 11 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [768, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 8], 1, Concat, [1]], # cat backbone P5 + [-1, 3, C3, [768, False]], # 15 + + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 19 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 23 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 20], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 26 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 16], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [768, False]], # 29 (P5/32-large) + + [-1, 1, Conv, [768, 3, 2]], + [[-1, 12], 1, Concat, [1]], # cat head P6 + [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) + + [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) + ] diff --git a/ultralytics/models/v5/yolov5m6u.yaml b/ultralytics/models/v5/yolov5m6u.yaml new file mode 100644 index 000000000..84274ea07 --- /dev/null +++ b/ultralytics/models/v5/yolov5m6u.yaml @@ -0,0 +1,55 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.67 # model depth multiple +width_multiple: 0.75 # layer channel multiple + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 + [-1, 3, C3, [768]], + [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 11 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [768, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 8], 1, Concat, [1]], # cat backbone P5 + [-1, 3, C3, [768, False]], # 15 + + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 19 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 23 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 20], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 26 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 16], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [768, False]], # 29 (P5/32-large) + + [-1, 1, Conv, [768, 3, 2]], + [[-1, 12], 1, Concat, [1]], # cat head P6 + [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) + + [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) + ] diff --git a/ultralytics/models/v5/yolov5n6u.yaml b/ultralytics/models/v5/yolov5n6u.yaml new file mode 100644 index 000000000..5776879d9 --- /dev/null +++ b/ultralytics/models/v5/yolov5n6u.yaml @@ -0,0 +1,55 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.25 # layer channel multiple + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 + [-1, 3, C3, [768]], + [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 11 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [768, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 8], 1, Concat, [1]], # cat backbone P5 + [-1, 3, C3, [768, False]], # 15 + + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 19 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 23 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 20], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 26 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 16], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [768, False]], # 29 (P5/32-large) + + [-1, 1, Conv, [768, 3, 2]], + [[-1, 12], 1, Concat, [1]], # cat head P6 + [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) + + [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) + ] diff --git a/ultralytics/models/v5/yolov5s6u.yaml b/ultralytics/models/v5/yolov5s6u.yaml new file mode 100644 index 000000000..90a39c0b3 --- /dev/null +++ b/ultralytics/models/v5/yolov5s6u.yaml @@ -0,0 +1,55 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 0.33 # model depth multiple +width_multiple: 0.50 # layer channel multiple + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 + [-1, 3, C3, [768]], + [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 11 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [768, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 8], 1, Concat, [1]], # cat backbone P5 + [-1, 3, C3, [768, False]], # 15 + + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 19 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 23 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 20], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 26 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 16], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [768, False]], # 29 (P5/32-large) + + [-1, 1, Conv, [768, 3, 2]], + [[-1, 12], 1, Concat, [1]], # cat head P6 + [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) + + [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) + ] diff --git a/ultralytics/models/v5/yolov5su.yaml b/ultralytics/models/v5/yolov5su.yaml index 8ac4bf711..8cd3c5b93 100644 --- a/ultralytics/models/v5/yolov5su.yaml +++ b/ultralytics/models/v5/yolov5su.yaml @@ -5,7 +5,6 @@ nc: 80 # number of classes depth_multiple: 0.33 # model depth multiple width_multiple: 0.50 # layer channel multiple - # YOLOv5 v6.0 backbone backbone: # [from, number, module, args] diff --git a/ultralytics/models/v5/yolov5x6u.yaml b/ultralytics/models/v5/yolov5x6u.yaml new file mode 100644 index 000000000..31cd9c9ee --- /dev/null +++ b/ultralytics/models/v5/yolov5x6u.yaml @@ -0,0 +1,55 @@ +# Ultralytics YOLO 🚀, GPL-3.0 license + +# Parameters +nc: 80 # number of classes +depth_multiple: 1.33 # model depth multiple +width_multiple: 1.25 # layer channel multiple + +# YOLOv5 v6.0 backbone +backbone: + # [from, number, module, args] + [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2 + [-1, 1, Conv, [128, 3, 2]], # 1-P2/4 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 3-P3/8 + [-1, 6, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 5-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [768, 3, 2]], # 7-P5/32 + [-1, 3, C3, [768]], + [-1, 1, Conv, [1024, 3, 2]], # 9-P6/64 + [-1, 3, C3, [1024]], + [-1, 1, SPPF, [1024, 5]], # 11 + ] + +# YOLOv5 v6.0 head +head: + [[-1, 1, Conv, [768, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 8], 1, Concat, [1]], # cat backbone P5 + [-1, 3, C3, [768, False]], # 15 + + [-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 6], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 19 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 23 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 20], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 26 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 16], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [768, False]], # 29 (P5/32-large) + + [-1, 1, Conv, [768, 3, 2]], + [[-1, 12], 1, Concat, [1]], # cat head P6 + [-1, 3, C3, [1024, False]], # 32 (P6/64-xlarge) + + [[23, 26, 29, 32], 1, Detect, [nc]], # Detect(P3, P4, P5, P6) + ] diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index 81dc43345..9b7aa1a76 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -27,6 +27,10 @@ def check_class_names(names): if isinstance(names, dict): if not all(isinstance(k, int) for k in names.keys()): # convert string keys to int, i.e. '0' to 0 names = {int(k): v for k, v in names.items()} + n = len(names) + if max(names.keys()) >= n: + raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices ' + f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.') if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764' map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map'] # human-readable names names = {k: map[v] for k, v in names.items()} @@ -35,12 +39,14 @@ def check_class_names(names): class AutoBackend(nn.Module): - def _apply_default_class_names(self, data): - with contextlib.suppress(Exception): - return yaml_load(check_yaml(data))['names'] - return {i: f'class{i}' for i in range(999)} # return default if above errors - - def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True): + def __init__(self, + weights='yolov8n.pt', + device=torch.device('cpu'), + dnn=False, + data=None, + fp16=False, + fuse=True, + verbose=True): """ MultiBackend class for python inference on various platforms using Ultralytics YOLO. @@ -51,6 +57,7 @@ class AutoBackend(nn.Module): data (str), (Path): Additional data.yaml file for class names, optional fp16 (bool): If True, use half precision. Default: False fuse (bool): Whether to fuse the model or not. Default: True + verbose (bool): Whether to run in verbose mode or not. Default: True Supported formats and their naming conventions: | Format | Suffix | @@ -83,7 +90,7 @@ class AutoBackend(nn.Module): # NOTE: special case: in-memory pytorch model if nn_module: model = weights.to(device) - model = model.fuse() if fuse else model + model = model.fuse(verbose=verbose) if fuse else model names = model.module.names if hasattr(model, 'module') else model.names # get class names stride = max(int(model.stride.max()), 32) # model stride model.half() if fp16 else model.float() @@ -410,6 +417,12 @@ class AutoBackend(nn.Module): for _ in range(2 if self.jit else 1): # self.forward(im) # warmup + @staticmethod + def _apply_default_class_names(data): + with contextlib.suppress(Exception): + return yaml_load(check_yaml(data))['names'] + return {i: f'class{i}' for i in range(999)} # return default if above errors + @staticmethod def _model_type(p='path/to/model.pt'): """ diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index e56ea813a..9cab0c41c 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -8,9 +8,7 @@ import thop import torch import torch.nn as nn -from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify, - Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus, - GhostBottleneck, GhostConv, Segment) +from ultralytics.nn.modules import * # noqa: F403 from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, colorstr, emojis, yaml_load from ultralytics.yolo.utils.checks import check_requirements, check_yaml from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, @@ -87,7 +85,7 @@ class BaseModel(nn.Module): if c: LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total") - def fuse(self): + def fuse(self, verbose=True): """ Fuse the `Conv2d()` and `BatchNorm2d()` layers of the model into a single layer, in order to improve the computation efficiency. @@ -105,7 +103,7 @@ class BaseModel(nn.Module): m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn) delattr(m, 'bn') # remove batchnorm m.forward = m.forward_fuse # update forward - self.info() + self.info(verbose=verbose) return self @@ -130,7 +128,7 @@ class BaseModel(nn.Module): verbose (bool): if True, prints out the model information. Defaults to False imgsz (int): the size of the image that the model will be trained on. Defaults to 640 """ - model_info(self, verbose, imgsz) + model_info(self, verbose=verbose, imgsz=imgsz) def _apply(self, fn): """ @@ -437,7 +435,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) ch = [ch] layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args - m = eval(m) if isinstance(m, str) else m # eval strings + m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module for j, a in enumerate(args): # TODO: re-implement with eval() removal if possible # args[j] = (locals()[a] if a in locals() else ast.literal_eval(a)) if isinstance(a, str) else a diff --git a/ultralytics/yolo/cfg/__init__.py b/ultralytics/yolo/cfg/__init__.py index ea25e87bd..c38c93d63 100644 --- a/ultralytics/yolo/cfg/__init__.py +++ b/ultralytics/yolo/cfg/__init__.py @@ -61,8 +61,10 @@ CFG_BOOL_KEYS = ('save', 'exist_ok', 'pretrained', 'verbose', 'deterministic', ' 'v5loader') # Define valid tasks and modes -TASKS = 'detect', 'segment', 'classify' MODES = 'train', 'val', 'predict', 'export', 'track', 'benchmark' +TASKS = 'detect', 'segment', 'classify' +TASK2DATA = {'detect': 'coco128.yaml', 'segment': 'coco128-seg.yaml', 'classify': 'imagenet100'} +TASK2MODEL = {'detect': 'yolov8n.pt', 'segment': 'yolov8n-seg.pt', 'classify': 'yolov8n-cls.pt'} def cfg2dict(cfg): @@ -274,8 +276,11 @@ def entrypoint(debug=''): # Task task = overrides.pop('task', None) - if task and task not in TASKS: - raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + if task: + if task not in TASKS: + raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}") + if 'model' not in overrides: + overrides['model'] = TASK2MODEL[task] # Model model = overrides.pop('model', DEFAULT_CFG.model) @@ -287,9 +292,10 @@ def entrypoint(debug=''): model = YOLO(model, task=task) # Task Update - if task and task != model.task: - LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " - f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.") + if task != model.task: + if task: + LOGGER.warning(f"WARNING ⚠️ conflicting 'task={task}' passed with 'task={model.task}' model. " + f"Ignoring 'task={task}' and updating to 'task={model.task}' to match model.") task = model.task # Mode @@ -299,8 +305,7 @@ def entrypoint(debug=''): LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using default 'source={overrides['source']}'.") elif mode in ('train', 'val'): if 'data' not in overrides: - task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100') - overrides['data'] = task2data.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) + overrides['data'] = TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data) LOGGER.warning(f"WARNING ⚠️ 'data' is missing. Using default 'data={overrides['data']}'.") elif mode == 'export': if 'format' not in overrides: @@ -322,4 +327,4 @@ def copy_default_cfg(): if __name__ == '__main__': # entrypoint(debug='yolo predict model=yolov8n.pt') - entrypoint(debug='') + entrypoint(debug='yolo train model=yolov8n-seg.pt') diff --git a/ultralytics/yolo/cfg/default.yaml b/ultralytics/yolo/cfg/default.yaml index fd7ad9bec..f1e6e97a9 100644 --- a/ultralytics/yolo/cfg/default.yaml +++ b/ultralytics/yolo/cfg/default.yaml @@ -6,7 +6,7 @@ mode: train # YOLO mode, i.e. train, val, predict, export # Train settings ------------------------------------------------------------------------------------------------------- model: # path to model file, i.e. yolov8n.pt, yolov8n.yaml -data: # path to data file, i.e. i.e. coco128.yaml +data: # path to data file, i.e. coco128.yaml epochs: 100 # number of epochs to train for patience: 50 # epochs to wait for no observable improvement for early stopping of training batch: 16 # number of images per batch (-1 for AutoBatch) diff --git a/ultralytics/yolo/data/base.py b/ultralytics/yolo/data/base.py index 75577514b..cd4f9203f 100644 --- a/ultralytics/yolo/data/base.py +++ b/ultralytics/yolo/data/base.py @@ -35,7 +35,8 @@ class BaseDataset(Dataset): batch_size=None, stride=32, pad=0.5, - single_cls=False): + single_cls=False, + classes=None): super().__init__() self.img_path = img_path self.imgsz = imgsz @@ -45,8 +46,7 @@ class BaseDataset(Dataset): self.im_files = self.get_img_files(self.img_path) self.labels = self.get_labels() - if self.single_cls: - self.update_labels(include_class=[]) + self.update_labels(include_class=classes) # single_cls and include_class self.ni = len(self.labels) @@ -96,7 +96,7 @@ class BaseDataset(Dataset): """include_class, filter labels to include only these classes (optional)""" include_class_array = np.array(include_class).reshape(1, -1) for i in range(len(self.labels)): - if include_class: + if include_class is not None: cls = self.labels[i]['cls'] bboxes = self.labels[i]['bboxes'] segments = self.labels[i]['segments'] @@ -104,7 +104,7 @@ class BaseDataset(Dataset): self.labels[i]['cls'] = cls[j] self.labels[i]['bboxes'] = bboxes[j] if segments: - self.labels[i]['segments'] = segments[j] + self.labels[i]['segments'] = [segments[si] for si, idx in enumerate(j) if idx] if self.single_cls: self.labels[i]['cls'][:, 0] = 0 diff --git a/ultralytics/yolo/data/build.py b/ultralytics/yolo/data/build.py index 6a5fc1cc5..b79e661b2 100644 --- a/ultralytics/yolo/data/build.py +++ b/ultralytics/yolo/data/build.py @@ -10,7 +10,7 @@ from PIL import Image from torch.utils.data import DataLoader, dataloader, distributed from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots, - LoadStreams, SourceTypes, autocast_list) + LoadStreams, LoadTensor, SourceTypes, autocast_list) from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS from ultralytics.yolo.utils.checks import check_file @@ -82,7 +82,8 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra prefix=colorstr(f'{mode}: '), use_segments=cfg.task == 'segment', use_keypoints=cfg.task == 'keypoint', - names=names) + names=names, + classes=cfg.classes) batch = min(batch, len(dataset)) nd = torch.cuda.device_count() # number of CUDA devices @@ -133,7 +134,7 @@ def build_classification_dataloader(path, def check_source(source): - webcam, screenshot, from_img, in_memory = False, False, False, False + webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False if isinstance(source, (str, int, Path)): # int for local usb camera source = str(source) is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS) @@ -149,22 +150,25 @@ def check_source(source): from_img = True elif isinstance(source, (Image.Image, np.ndarray)): from_img = True + elif isinstance(source, torch.Tensor): + tensor = True else: raise TypeError('Unsupported image type. See docs for supported types https://docs.ultralytics.com/predict') - return source, webcam, screenshot, from_img, in_memory + return source, webcam, screenshot, from_img, in_memory, tensor def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True): """ TODO: docs """ - # source - source, webcam, screenshot, from_img, in_memory = check_source(source) - source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img) + source, webcam, screenshot, from_img, in_memory, tensor = check_source(source) + source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor) # Dataloader - if in_memory: + if tensor: + dataset = LoadTensor(source) + elif in_memory: dataset = source elif webcam: dataset = LoadStreams(source, diff --git a/ultralytics/yolo/data/dataloaders/stream_loaders.py b/ultralytics/yolo/data/dataloaders/stream_loaders.py index 613850a0d..9794f56dc 100644 --- a/ultralytics/yolo/data/dataloaders/stream_loaders.py +++ b/ultralytics/yolo/data/dataloaders/stream_loaders.py @@ -26,6 +26,7 @@ class SourceTypes: webcam: bool = False screenshot: bool = False from_img: bool = False + tensor: bool = False class LoadStreams: @@ -329,6 +330,23 @@ class LoadPilAndNumpy: return self +class LoadTensor: + + def __init__(self, imgs) -> None: + self.im0 = imgs + self.bs = imgs.shape[0] + + def __iter__(self): + self.count = 0 + return self + + def __next__(self): + if self.count == 1: + raise StopIteration + self.count += 1 + return None, self.im0, self.im0, None, '' # self.paths, im, self.im0, None, '' + + def autocast_list(source): """ Merges a list of source of different types into a list of numpy arrays or PIL images diff --git a/ultralytics/yolo/data/dataloaders/v5loader.py b/ultralytics/yolo/data/dataloaders/v5loader.py index 803935778..4a6c709da 100644 --- a/ultralytics/yolo/data/dataloaders/v5loader.py +++ b/ultralytics/yolo/data/dataloaders/v5loader.py @@ -539,7 +539,7 @@ class LoadImagesAndLabels(Dataset): j = (label[:, 0:1] == include_class_array).any(1) self.labels[i] = label[j] if segment: - self.segments[i] = segment[j] + self.segments[i] = [segment[si] for si, idx in enumerate(j) if idx] if single_cls: # single-class training, merge all classes into 0 self.labels[i][:, 0] = 0 diff --git a/ultralytics/yolo/data/dataset.py b/ultralytics/yolo/data/dataset.py index abb743c6f..af1123e38 100644 --- a/ultralytics/yolo/data/dataset.py +++ b/ultralytics/yolo/data/dataset.py @@ -57,12 +57,14 @@ class YOLODataset(BaseDataset): single_cls=False, use_segments=False, use_keypoints=False, - names=None): + names=None, + classes=None): self.use_segments = use_segments self.use_keypoints = use_keypoints self.names = names assert not (self.use_segments and self.use_keypoints), 'Can not use both segments and keypoints.' - super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls) + super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls, + classes) def cache_labels(self, path=Path('./labels.cache')): """Cache dataset labels, check images and read shapes. diff --git a/ultralytics/yolo/data/utils.py b/ultralytics/yolo/data/utils.py index fe23ec372..5a96fb920 100644 --- a/ultralytics/yolo/data/utils.py +++ b/ultralytics/yolo/data/utils.py @@ -16,6 +16,7 @@ import numpy as np from PIL import ExifTags, Image, ImageOps from tqdm import tqdm +from ultralytics.nn.autobackend import check_class_names from ultralytics.yolo.utils import DATASETS_DIR, LOGGER, NUM_THREADS, ROOT, colorstr, emojis, yaml_load from ultralytics.yolo.utils.checks import check_file, check_font, is_ascii from ultralytics.yolo.utils.downloads import download, safe_download, unzip_file @@ -211,8 +212,7 @@ def check_det_dataset(dataset, autodownload=True): raise SyntaxError( emojis(f"{dataset} '{k}:' key missing ❌.\n" f"'train', 'val' and 'names' are required in data.yaml files.")) - if isinstance(data['names'], (list, tuple)): # old array format - data['names'] = dict(enumerate(data['names'])) # convert to dict + data['names'] = check_class_names(data['names']) data['nc'] = len(data['names']) # Resolve paths diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py index c2517224b..477b6e287 100644 --- a/ultralytics/yolo/engine/exporter.py +++ b/ultralytics/yolo/engine/exporter.py @@ -574,7 +574,7 @@ class Exporter: LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') saved_model = Path(str(self.file).replace(self.file.suffix, '_saved_model')) if self.args.int8: - f = saved_model / (self.file.stem + 'yolov8n_integer_quant.tflite') # fp32 in/out + f = saved_model / (self.file.stem + '_integer_quant.tflite') # fp32 in/out elif self.args.half: f = saved_model / (self.file.stem + '_float16.tflite') else: @@ -863,18 +863,6 @@ def export(cfg=DEFAULT_CFG): cfg.model = cfg.model or 'yolov8n.yaml' cfg.format = cfg.format or 'torchscript' - # exporter = Exporter(cfg) - # - # model = None - # if isinstance(cfg.model, (str, Path)): - # if Path(cfg.model).suffix == '.yaml': - # model = DetectionModel(cfg.model) - # elif Path(cfg.model).suffix == '.pt': - # model = attempt_load_weights(cfg.model, fuse=True) - # else: - # TypeError(f'Unsupported model type {cfg.model}') - # exporter(model=model) - from ultralytics import YOLO model = YOLO(cfg.model) model.export(**vars(cfg)) diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py index d782c6d13..b86dca728 100644 --- a/ultralytics/yolo/engine/model.py +++ b/ultralytics/yolo/engine/model.py @@ -203,6 +203,8 @@ class YOLO: if source is None: source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.") + is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and \ + ('predict' in sys.argv or 'mode=predict' in sys.argv) overrides = self.overrides.copy() overrides['conf'] = 0.25 @@ -213,10 +215,9 @@ class YOLO: if not self.predictor: self.task = overrides.get('task') or self.task self.predictor = TASK_MAP[self.task][3](overrides=overrides) - self.predictor.setup_model(model=self.model) + self.predictor.setup_model(model=self.model, verbose=is_cli) else: # only update args if predictor is already setup self.predictor.args = get_cfg(self.predictor.args, overrides) - is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics') return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream) def track(self, source=None, stream=False, **kwargs): diff --git a/ultralytics/yolo/engine/predictor.py b/ultralytics/yolo/engine/predictor.py index de15d5199..66342bd16 100644 --- a/ultralytics/yolo/engine/predictor.py +++ b/ultralytics/yolo/engine/predictor.py @@ -183,6 +183,8 @@ class BasePredictor: 'preprocess': self.dt[0].dt * 1E3 / n, 'inference': self.dt[1].dt * 1E3 / n, 'postprocess': self.dt[2].dt * 1E3 / n} + if self.source_type.tensor: # skip write, show and plot operations if input is raw tensor + continue p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \ else (path, im0s.copy()) p = Path(p) @@ -218,11 +220,16 @@ class BasePredictor: self.run_callbacks('on_predict_end') - def setup_model(self, model): - device = select_device(self.args.device) + def setup_model(self, model, verbose=True): + device = select_device(self.args.device, verbose=verbose) model = model or self.args.model self.args.half &= device.type != 'cpu' # half precision only supported on CUDA - self.model = AutoBackend(model, device=device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half) + self.model = AutoBackend(model, + device=device, + dnn=self.args.dnn, + data=self.args.data, + fp16=self.args.half, + verbose=verbose) self.device = device self.model.eval() diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py index a0ce21518..27abb7f24 100644 --- a/ultralytics/yolo/engine/trainer.py +++ b/ultralytics/yolo/engine/trainer.py @@ -25,8 +25,8 @@ from tqdm import tqdm from ultralytics.nn.tasks import attempt_load_one_weight, attempt_load_weights from ultralytics.yolo.cfg import get_cfg from ultralytics.yolo.data.utils import check_cls_dataset, check_det_dataset -from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, __version__, callbacks, - colorstr, emojis, yaml_save) +from ultralytics.yolo.utils import (DEFAULT_CFG, LOGGER, ONLINE, RANK, ROOT, SETTINGS, TQDM_BAR_FORMAT, __version__, + callbacks, colorstr, emojis, yaml_save) from ultralytics.yolo.utils.autobatch import check_train_batch_size from ultralytics.yolo.utils.checks import check_file, check_imgsz, print_args from ultralytics.yolo.utils.dist import ddp_cleanup, generate_ddp_command @@ -111,8 +111,6 @@ class BaseTrainer: print_args(vars(self.args)) # Device - self.amp = self.device.type != 'cpu' - self.scaler = amp.GradScaler(enabled=self.amp) if self.device.type == 'cpu': self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading @@ -126,7 +124,7 @@ class BaseTrainer: if 'yaml_file' in self.data: self.args.data = self.data['yaml_file'] # for validating 'yolo train data=url.zip' usage except Exception as e: - raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e + raise RuntimeError(emojis(f"Dataset '{self.args.data}' error ❌ {e}")) from e self.trainset, self.testset = self.get_dataset(self.data) self.ema = None @@ -204,6 +202,8 @@ class BaseTrainer: ckpt = self.setup_model() self.model = self.model.to(self.device) self.set_model_attributes() + self.amp = check_amp(self.model) + self.scaler = amp.GradScaler(enabled=self.amp) if world_size > 1: self.model = DDP(self.model, device_ids=[rank]) # Check imgsz @@ -597,3 +597,31 @@ class BaseTrainer: LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups " f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias') return optimizer + + +def check_amp(model): + # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation + device = next(model.parameters()).device # get model device + if device.type in ('cpu', 'mps'): + return False # AMP only used on CUDA devices + + def amp_allclose(m, im): + # All close FP32 vs AMP results + a = m(im, device=device, verbose=False)[0].boxes.boxes # FP32 inference + with torch.cuda.amp.autocast(True): + b = m(im, device=device, verbose=False)[0].boxes.boxes # AMP inference + return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.1) # close to 10% absolute tolerance + + f = ROOT / 'assets/bus.jpg' # image to check + im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3)) + prefix = colorstr('AMP: ') + try: + from ultralytics import YOLO + LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...') + assert amp_allclose(YOLO('yolov8n.pt'), im) + LOGGER.info(f'{prefix}checks passed ✅') + return True + except AssertionError: + LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to ' + f'NaN losses or zero-mAP results, so AMP will be disabled during training.') + return False diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py index edeea3626..4d3a9088a 100644 --- a/ultralytics/yolo/utils/checks.py +++ b/ultralytics/yolo/utils/checks.py @@ -236,9 +236,10 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''): def check_yolov5u_filename(file: str, verbose: bool = True): # Replace legacy YOLOv5 filenames with updated YOLOv5u filenames - if 'yolov3' in file or 'yolov5' in file and 'u' not in file: + if ('yolov3' in file or 'yolov5' in file) and 'u' not in file: original_file = file file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file) # i.e. yolov5n.pt -> yolov5nu.pt + file = re.sub(r'(.*yolov5([nsmlx])6)\.', '\\1u.', file) # i.e. yolov5n6.pt -> yolov5n6u.pt file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt if file != original_file and verbose: LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " diff --git a/ultralytics/yolo/utils/torch_utils.py b/ultralytics/yolo/utils/torch_utils.py index e918808df..61c7b721b 100644 --- a/ultralytics/yolo/utils/torch_utils.py +++ b/ultralytics/yolo/utils/torch_utils.py @@ -162,11 +162,13 @@ def fuse_deconv_and_bn(deconv, bn): return fuseddconv -def model_info(model, verbose=False, imgsz=640): +def model_info(model, detailed=False, verbose=True, imgsz=640): # Model information. imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320] + if not verbose: + return n_p = get_num_params(model) n_g = get_num_gradients(model) # number gradients - if verbose: + if detailed: LOGGER.info( f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}") for i, (name, p) in enumerate(model.named_parameters()): diff --git a/ultralytics/yolo/v8/classify/predict.py b/ultralytics/yolo/v8/classify/predict.py index 36a01eb36..790fcee69 100644 --- a/ultralytics/yolo/v8/classify/predict.py +++ b/ultralytics/yolo/v8/classify/predict.py @@ -14,14 +14,13 @@ class ClassificationPredictor(BasePredictor): return Annotator(img, example=str(self.model.names), pil=True) def preprocess(self, img): - img = (img if isinstance(img, torch.Tensor) else torch.Tensor(img)).to(self.model.device) - img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 - return img + img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) + return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 - def postprocess(self, preds, img, orig_img): + def postprocess(self, preds, img, orig_imgs): results = [] for i, pred in enumerate(preds): - orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img + orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs path, _, _, _, _ = self.batch img_path = path[i] if isinstance(path, list) else path results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, probs=pred)) diff --git a/ultralytics/yolo/v8/detect/predict.py b/ultralytics/yolo/v8/detect/predict.py index ecd51697d..c83f39c72 100644 --- a/ultralytics/yolo/v8/detect/predict.py +++ b/ultralytics/yolo/v8/detect/predict.py @@ -14,12 +14,12 @@ class DetectionPredictor(BasePredictor): return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names)) def preprocess(self, img): - img = torch.from_numpy(img).to(self.model.device) + img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device) img = img.half() if self.model.fp16 else img.float() # uint8 to fp16/32 img /= 255 # 0 - 255 to 0.0 - 1.0 return img - def postprocess(self, preds, img, orig_img): + def postprocess(self, preds, img, orig_imgs): preds = ops.non_max_suppression(preds, self.args.conf, self.args.iou, @@ -29,7 +29,7 @@ class DetectionPredictor(BasePredictor): results = [] for i, pred in enumerate(preds): - orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img + orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs shape = orig_img.shape pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round() path, _, _, _, _ = self.batch diff --git a/ultralytics/yolo/v8/segment/predict.py b/ultralytics/yolo/v8/segment/predict.py index 51110e275..2c004f0aa 100644 --- a/ultralytics/yolo/v8/segment/predict.py +++ b/ultralytics/yolo/v8/segment/predict.py @@ -10,7 +10,7 @@ from ultralytics.yolo.v8.detect.predict import DetectionPredictor class SegmentationPredictor(DetectionPredictor): - def postprocess(self, preds, img, orig_img): + def postprocess(self, preds, img, orig_imgs): # TODO: filter by classes p = ops.non_max_suppression(preds[0], self.args.conf, @@ -22,7 +22,7 @@ class SegmentationPredictor(DetectionPredictor): results = [] proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported for i, pred in enumerate(p): - orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img + orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs shape = orig_img.shape path, _, _, _, _ = self.batch img_path = path[i] if isinstance(path, list) else path