diff --git a/docker/Dockerfile b/docker/Dockerfile
index de3e101017..3bbecd52fa 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -3,18 +3,18 @@
 # Image is CUDA-optimized for YOLOv8 single/multi-GPU training and inference

 # Start FROM NVIDIA PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
-FROM nvcr.io/nvidia/pytorch:23.01-py3
+# FROM docker.io/pytorch/pytorch:latest
+FROM pytorch/pytorch:latest

 # Downloads to user config dir
 ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/

-# Remove torch nightly and install torch stable
-RUN rm -rf /opt/pytorch  # remove 1.2GB dir
-RUN pip uninstall -y torchtext pillow torch torchvision
-RUN pip install --no-cache torch torchvision
-
 # Install linux packages
-RUN apt update && apt install --no-install-recommends -y zip htop screen libgl1-mesa-glx
+ENV DEBIAN_FRONTEND noninteractive
+RUN apt update
+RUN TZ=Etc/UTC apt install -y tzdata
+RUN apt install --no-install-recommends -y git zip curl htop libgl1-mesa-glx libglib2.0-0 libpython3-dev gnupg
+# RUN alias python=python3

 # Create working directory
 RUN mkdir -p /usr/src/ultralytics
@@ -25,12 +25,18 @@ WORKDIR /usr/src/ultralytics
 RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics

 # Install pip packages
-RUN python -m pip install --upgrade pip wheel
-RUN pip install --no-cache ultralytics albumentations comet gsutil notebook
+COPY requirements.txt .
+RUN python3 -m pip install --upgrade pip wheel
+RUN pip install --no-cache ultralytics albumentations comet gsutil notebook \
+    coremltools onnx onnx-simplifier onnxruntime openvino-dev>=2022.3
+    # tensorflow tensorflowjs \

 # Set environment variables
 ENV OMP_NUM_THREADS=1

+# Cleanup
+ENV DEBIAN_FRONTEND teletype
+

 # Usage Examples -------------------------------------------------------------------------------------------------------

diff --git a/docker/Dockerfile-cpu b/docker/Dockerfile-cpu
index dc9143d74f..0b585e803a 100644
--- a/docker/Dockerfile-cpu
+++ b/docker/Dockerfile-cpu
@@ -27,8 +27,8 @@ RUN git clone https://github.com/ultralytics/ultralytics /usr/src/ultralytics
 COPY requirements.txt .
 RUN python3 -m pip install --upgrade pip wheel
 RUN pip install --no-cache ultralytics albumentations gsutil notebook \
-    coremltools onnx onnx-simplifier onnxruntime tensorflow-cpu \
-    # openvino-dev>=2022.3 tensorflowjs \
+    coremltools onnx onnx-simplifier onnxruntime openvino-dev>=2022.3 \
+    # tensorflow-cpu tensorflowjs \
     --extra-index-url https://download.pytorch.org/whl/cpu

 # Cleanup
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index eef5efa3b5..8575f0736b 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

-__version__ = "8.0.26"
+__version__ = "8.0.27"

 from ultralytics.yolo.engine.model import YOLO
 from ultralytics.yolo.utils import ops
diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index f32b4bf394..9248735dca 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -18,6 +18,16 @@ from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
 from ultralytics.yolo.utils.ops import xywh2xyxy


+def check_class_names(names):
+    # Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.
+    if isinstance(names, list):  # names is a list
+        names = dict(enumerate(names))  # convert to dict
+    if isinstance(names[0], str) and names[0].startswith('n0'):  # imagenet class codes, i.e. 'n01440764'
+        map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map']  # human-readable names
+        names = {k: map[v] for k, v in names.items()}
+    return names
+
+
 class AutoBackend(nn.Module):

     def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
@@ -228,11 +238,7 @@ class AutoBackend(nn.Module):

         # class names
         if 'names' not in locals():  # names missing
             names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}  # assign default
-        elif isinstance(names, list):  # names is a list
-            names = dict(enumerate(names))  # convert to dict
-            if isinstance(names[0], str) and names[0].startswith('n0'):  # imagenet class codes, i.e. 'n01440764'
-                map = yaml_load(ROOT / 'yolo/data/datasets/ImageNet.yaml')['map']  # human-readable names
-                names = {k: map[v] for k, v in names.items()}
+        names = check_class_names(names)

         self.__dict__.update(locals())  # assign all variables to self
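For reference, a minimal usage sketch of the `check_class_names` helper introduced above, assuming ultralytics 8.0.27+ is installed with its bundled `ImageNet.yaml`; the inputs and printed outputs are illustrative only:

```python
from ultralytics.nn.autobackend import check_class_names

# A plain list of names becomes a dict keyed by integer class index
print(check_class_names(['person', 'bicycle', 'car']))
# -> {0: 'person', 1: 'bicycle', 2: 'car'}

# ImageNet synset codes such as 'n01440764' are remapped to human-readable names
# via the 'map' section of yolo/data/datasets/ImageNet.yaml
print(check_class_names({0: 'n01440764', 1: 'n01443537'}))
# -> e.g. {0: 'tench', 1: 'goldfish'}
```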
diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py
index 183670a410..8150b20df4 100644
--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -347,23 +347,24 @@ def torch_safe_load(weight):

 def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
     # Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
-    model = Ensemble()
+    ensemble = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         ckpt = torch_safe_load(w)  # load ckpt
         args = {**DEFAULT_CFG_DICT, **ckpt['train_args']}  # combine model and default args, preferring model args
-        ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
+        model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

         # Model compatibility updates
-        ckpt.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
-        ckpt.pt_path = weights  # attach *.pt file path to model
-        if not hasattr(ckpt, 'stride'):
-            ckpt.stride = torch.tensor([32.])
+        model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
+        model.pt_path = weights  # attach *.pt file path to model
+        model.task = guess_model_task(model)
+        if not hasattr(model, 'stride'):
+            model.stride = torch.tensor([32.])

         # Append
-        model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval())  # model in eval mode
+        ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval())  # model in eval mode

     # Module compatibility updates
-    for m in model.modules():
+    for m in ensemble.modules():
         t = type(m)
         if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment):
             m.inplace = inplace  # torch 1.7.0 compatibility
@@ -371,16 +372,16 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
             m.recompute_scale_factor = None  # torch 1.11.0 compatibility

     # Return model
-    if len(model) == 1:
-        return model[-1]
+    if len(ensemble) == 1:
+        return ensemble[-1]

     # Return ensemble
     print(f'Ensemble created with {weights}\n')
     for k in 'names', 'nc', 'yaml':
-        setattr(model, k, getattr(model[0], k))
-    model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride  # max stride
-    assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
-    return model
+        setattr(ensemble, k, getattr(ensemble[0], k))
+    ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
+    assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts: {[m.nc for m in ensemble]}'
+    return ensemble


 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
@@ -392,6 +393,7 @@
     # Model compatibility updates
     model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # attach args to model
     model.pt_path = weight  # attach *.pt file path to model
+    model.task = guess_model_task(model)
     if not hasattr(model, 'stride'):
         model.stride = torch.tensor([32.])

diff --git a/ultralytics/yolo/data/datasets/ImageNet.yaml b/ultralytics/yolo/data/datasets/ImageNet.yaml
index fc2fc52752..c42c0eb5ab 100644
--- a/ultralytics/yolo/data/datasets/ImageNet.yaml
+++ b/ultralytics/yolo/data/datasets/ImageNet.yaml
@@ -1153,7 +1153,7 @@ map:
   n02009229: little_blue_heron
   n02009912: American_egret
   n02011460: bittern
-  n02012849: crane
+  n02012849: crane_(bird)
   n02013706: limpkin
   n02017213: European_gallinule
   n02018207: American_coot
@@ -1536,7 +1536,7 @@ map:
   n03124043: cowboy_boot
   n03124170: cowboy_hat
   n03125729: cradle
-  n03126707: crane
+  n03126707: crane_(machine)
   n03127747: crash_helmet
   n03127925: crate
   n03131574: crib
@@ -1657,8 +1657,8 @@ map:
   n03706229: magnetic_compass
   n03709823: mailbag
   n03710193: mailbox
-  n03710637: maillot
-  n03710721: maillot
+  n03710637: maillot_(tights)
+  n03710721: maillot_(tank_suit)
   n03717622: manhole_cover
   n03720891: maraca
   n03721384: marimba
diff --git a/ultralytics/yolo/engine/exporter.py b/ultralytics/yolo/engine/exporter.py
index ba67e61cc9..5e3123770f 100644
--- a/ultralytics/yolo/engine/exporter.py
+++ b/ultralytics/yolo/engine/exporter.py
@@ -65,6 +65,7 @@ import pandas as pd
 import torch

 import ultralytics
+from ultralytics.nn.autobackend import check_class_names
 from ultralytics.nn.modules import Detect, Segment
 from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task
 from ultralytics.yolo.cfg import get_cfg
@@ -151,9 +152,12 @@
             assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'

         # Checks
+        model.names = check_class_names(model.names)
         # if self.args.batch == model.args['batch_size']:  # user has not modified training batch_size
         self.args.batch = 1
         self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2)  # check image size
+        if model.task == 'classify':
+            self.args.nms = self.args.agnostic_nms = False
         if self.args.optimize:
             assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'

@@ -194,8 +198,14 @@
         self.model = model
         self.file = file
         self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else (x.shape for x in y)
-        self.metadata = {'stride': int(max(model.stride)), 'names': model.names}  # model metadata
         self.pretty_name = self.file.stem.replace('yolo', 'YOLO')
+        self.metadata = {
+            'description': f"Ultralytics {self.pretty_name} model trained on {self.model.args['data']}",
+            'author': 'Ultralytics',
+            'license': 'GPL-3.0 https://ultralytics.com/license',
+            'version': ultralytics.__version__,
+            'stride': int(max(model.stride)),
+            'names': model.names}  # model metadata

         # Exports
         f = [''] * len(fmts)  # exported filenames
@@ -235,12 +245,11 @@
         # Finish
         f = [str(x) for x in f if x]  # filter out '' and None
         if any(f):
-            task = guess_model_task(model)
             s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
             LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
                         f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
-                        f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}"
-                        f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}"
+                        f"\nPredict: yolo task={model.task} mode=predict model={f[-1]} {s}"
+                        f"\nValidate: yolo task={model.task} mode=val model={f[-1]} {s}"
                         f"\nVisualize: https://netron.app")

         self.run_callbacks("on_export_end")
@@ -375,9 +384,13 @@
         LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
         f = self.file.with_suffix('.mlmodel')

+        task = self.model.task
         model = iOSModel(self.model, self.im).eval() if self.args.nms else self.model
         ts = torch.jit.trace(model, self.im, strict=False)  # TorchScript model
-        ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])])
+        classifier_config = ct.ClassifierConfig(list(model.names.values())) if task == 'classify' else None
+        ct_model = ct.convert(ts,
+                              inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])],
+                              classifier_config=classifier_config)
         bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
         if bits < 32:
             if MACOS:  # quantization only supported on macOS
@@ -387,6 +400,10 @@

         if self.args.nms:
             ct_model = self._pipeline_coreml(ct_model)

+        ct_model.short_description = self.metadata['description']
+        ct_model.author = self.metadata['author']
+        ct_model.license = self.metadata['license']
+        ct_model.version = self.metadata['version']
         ct_model.save(str(f))
         return f, ct_model
@@ -687,8 +704,8 @@
             out0_shape = out[out0.name].shape
             out1_shape = out[out1.name].shape
         else:  # linux and windows can not run model.predict(), get sizes from pytorch output y
-            out0_shape = self.output_shape[1], self.output_shape[2] - 5  # (3780, 80)
-            out1_shape = self.output_shape[1], 4  # (3780, 4)
+            out0_shape = self.output_shape[2], self.output_shape[1] - 4  # (3780, 80)
+            out1_shape = self.output_shape[2], 4  # (3780, 4)

         # Checks
         names = self.metadata['names']
@@ -714,7 +731,7 @@
         # flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r)

         # Print
-        print(spec.description)
+        # print(spec.description)

         # Model from spec
         model = ct.models.MLModel(spec)
@@ -771,10 +788,6 @@

         # Update metadata
         pipeline.spec.specificationVersion = 5
-        pipeline.spec.description.metadata.versionString = f'Ultralytics YOLOv{ultralytics.__version__}'
-        pipeline.spec.description.metadata.shortDescription = f'Ultralytics {self.pretty_name} CoreML model'
-        pipeline.spec.description.metadata.author = 'Ultralytics (https://ultralytics.com)'
-        pipeline.spec.description.metadata.license = 'GPL-3.0 license (https://ultralytics.com/license)'
         pipeline.spec.description.metadata.userDefined.update({
             'IoU threshold': str(nms.iouThreshold),
             'Confidence threshold': str(nms.confidenceThreshold)})
diff --git a/ultralytics/yolo/engine/model.py b/ultralytics/yolo/engine/model.py
index f6a5552fd3..5107d7330a 100644
--- a/ultralytics/yolo/engine/model.py
+++ b/ultralytics/yolo/engine/model.py
@@ -1,6 +1,7 @@
 # Ultralytics YOLO 🚀, GPL-3.0 license

 from pathlib import Path
+from typing import List
 import sys

 from ultralytics import yolo  # noqa
@@ -9,7 +10,7 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
 from ultralytics.yolo.cfg import get_cfg
 from ultralytics.yolo.engine.exporter import Exporter
 from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, RANK, callbacks, yaml_load
-from ultralytics.yolo.utils.checks import check_yaml
+from ultralytics.yolo.utils.checks import check_yaml, check_imgsz
 from ultralytics.yolo.utils.torch_utils import smart_inference_mode

 # Map head to model, trainer, validator, and predictor classes
@@ -131,7 +132,7 @@
         Check the 'configuration' section in the documentation for all available options.

         Returns:
-            (dict): The prediction results.
+            (List[ultralytics.yolo.engine.results.Results]): The prediction results.
         """
         overrides = self.overrides.copy()
         overrides["conf"] = 0.25
@@ -161,6 +162,7 @@
         args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
         args.data = data or args.data
         args.task = self.task
+        args.imgsz = check_imgsz(args.imgsz, max_dim=1)

         validator = self.ValidatorClass(args=args)
         validator(model=self.model)
diff --git a/ultralytics/yolo/engine/trainer.py b/ultralytics/yolo/engine/trainer.py
index 0dd49e094b..333598231b 100644
--- a/ultralytics/yolo/engine/trainer.py
+++ b/ultralytics/yolo/engine/trainer.py
@@ -202,7 +202,7 @@ class BaseTrainer:
             self.model = DDP(self.model, device_ids=[rank])
         # Check imgsz
         gs = max(int(self.model.stride.max() if hasattr(self.model, 'stride') else 32), 32)  # grid size (max stride)
-        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs)
+        self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
         # Batch size
         if self.batch_size == -1:
             if RANK == -1:  # single-GPU only, estimate best batch size
diff --git a/ultralytics/yolo/utils/__init__.py b/ultralytics/yolo/utils/__init__.py
index 5023a182a2..4a045592b7 100644
--- a/ultralytics/yolo/utils/__init__.py
+++ b/ultralytics/yolo/utils/__init__.py
@@ -467,6 +467,13 @@ def set_sentry():
     """

     def before_send(event, hint):
+        if 'exc_info' in hint:
+            exc_type, exc_value, tb = hint['exc_info']
+            if exc_type in (KeyboardInterrupt, FileNotFoundError) \
+                    or 'out of memory' in str(exc_value) \
+                    or not sys.argv[0].endswith('yolo'):
+                return None  # do not send event
+
         env = 'Colab' if is_colab() else 'Kaggle' if is_kaggle() else 'Jupyter' if is_jupyter() else \
             'Docker' if is_docker() else platform.system()
         event['tags'] = {
@@ -477,6 +484,7 @@
         return event

     if SETTINGS['sync'] and \
+            RANK in {-1, 0} and \
             not is_pytest_running() and \
             not is_github_actions_ci() and \
             ((is_pip_package() and not is_git_dir()) or
@@ -491,7 +499,7 @@
             release=ultralytics.__version__,
             environment='production',  # 'dev' or 'production'
             before_send=before_send,
-            ignore_errors=[KeyboardInterrupt])
+            ignore_errors=[KeyboardInterrupt, FileNotFoundError])

         # Disable all sentry logging
         for logger in "sentry_sdk", "sentry_sdk.errors":
diff --git a/ultralytics/yolo/utils/checks.py b/ultralytics/yolo/utils/checks.py
index e864b1b05a..2fa129f2fb 100644
--- a/ultralytics/yolo/utils/checks.py
+++ b/ultralytics/yolo/utils/checks.py
@@ -40,7 +40,7 @@ def is_ascii(s) -> bool:
     return all(ord(c) < 128 for c in s)


-def check_imgsz(imgsz, stride=32, min_dim=1, floor=0):
+def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
     """
     Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
     stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
@@ -66,6 +66,13 @@
         raise TypeError(f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
                         f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'")

+    # Apply max_dim
+    if max_dim == 1:
+        LOGGER.warning(f"WARNING ⚠️ 'train' and 'val' imgsz types must be integer, updating to 'imgsz={max(imgsz)}'. "
+                       f"'predict' and 'export' imgsz may be list or integer, "
+                       f"i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'")
+        imgsz = [max(imgsz)]
+
     # Make image size a multiple of the stride
     sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
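For reference, a minimal sketch of the new `max_dim` argument to `check_imgsz`, assuming ultralytics 8.0.27+ is installed; the values shown are illustrative only:

```python
from ultralytics.yolo.utils.checks import check_imgsz

# predict/export style: the default max_dim=2 accepts an int or a 2-element list
print(check_imgsz([640, 480], stride=32))

# train/val style (as used in trainer.py and model.py above): max_dim=1 warns and
# collapses a list to its largest dimension, since these modes expect a single integer imgsz
print(check_imgsz([640, 480], stride=32, max_dim=1))
```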