`ultralytics 8.0.46` TFLite and Benchmarks updates (#1141)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/1051/head^2 v8.0.46
Glenn Jocher 2 years ago committed by GitHub
parent 3765f4f6d9
commit a82ee2c779
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 20
      .github/workflows/ci.yaml
  2. 2
      ultralytics/__init__.py
  3. 20
      ultralytics/yolo/cfg/__init__.py
  4. 36
      ultralytics/yolo/data/dataset.py
  5. 8
      ultralytics/yolo/engine/exporter.py
  6. 22
      ultralytics/yolo/engine/model.py
  7. 7
      ultralytics/yolo/engine/predictor.py
  8. 18
      ultralytics/yolo/utils/__init__.py
  9. 29
      ultralytics/yolo/utils/benchmarks.py
  10. 26
      ultralytics/yolo/utils/checks.py
  11. 16
      ultralytics/yolo/utils/dist.py

@ -55,18 +55,20 @@ jobs:
- name: Benchmark DetectionModel
shell: python
run: |
from ultralytics.yolo.utils.benchmarks import run_benchmarks
run_benchmarks(model='${{ matrix.model }}.pt', imgsz=160, half=False, hard_fail=False)
from ultralytics.yolo.utils.benchmarks import benchmark
benchmark(model='${{ matrix.model }}.pt', imgsz=160, half=False, hard_fail=0.20)
- name: Benchmark SegmentationModel
shell: python
run: |
from ultralytics.yolo.utils.benchmarks import run_benchmarks
run_benchmarks(model='${{ matrix.model }}-seg.pt', imgsz=160, half=False, hard_fail=False)
from ultralytics.yolo.utils.benchmarks import benchmark
benchmark(model='${{ matrix.model }}-seg.pt', imgsz=160, half=False, hard_fail=0.14)
- name: Benchmark ClassificationModel
shell: python
run: |
from ultralytics.yolo.utils.benchmarks import run_benchmarks
run_benchmarks(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=False)
from ultralytics.yolo.utils.benchmarks import benchmark
benchmark(model='${{ matrix.model }}-cls.pt', imgsz=160, half=False, hard_fail=0.70)
- name: Benchmark Summary
run: cat benchmarks.log
Tests:
timeout-minutes: 60
@ -88,10 +90,10 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Get cache dir
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
- name: Get cache dir # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
id: pip-cache
run: echo "::set-output name=dir::$(pip cache dir)"
run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash # for Windows compatibility
- name: Cache pip
uses: actions/cache@v3
with:

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.45'
__version__ = '8.0.46'
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks

@ -254,8 +254,8 @@ def entrypoint(debug=''):
else:
check_cfg_mismatch(full_args_dict, {a: ''})
# Defaults
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')
# Check keys
check_cfg_mismatch(full_args_dict, overrides)
# Mode
mode = overrides.get('mode', None)
@ -279,11 +279,12 @@ def entrypoint(debug=''):
model = YOLO(model)
# Task
task = overrides.get('task', None)
if task is not None and task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
else:
model.task = task
task = overrides.get('task', model.task)
if task is not None:
if task not in TASKS:
raise ValueError(f"Invalid 'task={task}'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}")
else:
model.task = task
# Mode
if mode in {'predict', 'track'} and 'source' not in overrides:
@ -292,8 +293,9 @@ def entrypoint(debug=''):
LOGGER.warning(f"WARNING ⚠ 'source' is missing. Using default 'source={overrides['source']}'.")
elif mode in ('train', 'val'):
if 'data' not in overrides:
overrides['data'] = task2data.get(overrides['task'], DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠ 'data' is missing. Using {model.task} default 'data={overrides['data']}'.")
task2data = dict(detect='coco128.yaml', segment='coco128-seg.yaml', classify='imagenet100')
overrides['data'] = task2data.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
LOGGER.warning(f"WARNING ⚠ 'data' is missing. Using default 'data={overrides['data']}'.")
elif mode == 'export':
if 'format' not in overrides:
overrides['format'] = DEFAULT_CFG.format or 'torchscript'

@ -16,10 +16,28 @@ from .utils import HELP_URL, LOCAL_RANK, get_hash, img2label_paths, verify_image
class YOLODataset(BaseDataset):
cache_version = '1.0.1' # dataset labels *.cache version, >= 1.0.0 for YOLOv8
rand_interp_methods = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4]
"""YOLO Dataset.
"""
Dataset class for loading images object detection and/or segmentation labels in YOLO format.
Args:
img_path (str): image path.
prefix (str): prefix.
img_path (str): path to the folder containing images.
imgsz (int): image size (default: 640).
cache (bool): if True, a cache file of the labels is created to speed up future creation of dataset instances
(default: False).
augment (bool): if True, data augmentation is applied (default: True).
hyp (dict): hyperparameters to apply data augmentation (default: None).
prefix (str): prefix to print in log messages (default: '').
rect (bool): if True, rectangular training is used (default: False).
batch_size (int): size of batches (default: None).
stride (int): stride (default: 32).
pad (float): padding (default: 0.0).
single_cls (bool): if True, single class training is used (default: False).
use_segments (bool): if True, segmentation masks are used as labels (default: False).
use_keypoints (bool): if True, keypoints are used as labels (default: False).
names (list): class names (default: None).
Returns:
A PyTorch dataset object that can be used for training an object detection or segmentation model.
"""
def __init__(self,
@ -44,7 +62,12 @@ class YOLODataset(BaseDataset):
super().__init__(img_path, imgsz, cache, augment, hyp, prefix, rect, batch_size, stride, pad, single_cls)
def cache_labels(self, path=Path('./labels.cache')):
# Cache dataset labels, check images and read shapes
"""Cache dataset labels, check images and read shapes.
Args:
path (Path): path where to save the cache file (default: Path('./labels.cache')).
Returns:
(dict): labels.
"""
x = {'labels': []}
nm, nf, ne, nc, msgs = 0, 0, 0, 0, [] # number missing, found, empty, corrupt, messages
desc = f'{self.prefix}Scanning {path.parent / path.stem}...'
@ -119,9 +142,8 @@ class YOLODataset(BaseDataset):
self.im_files = [lb['im_file'] for lb in labels] # update im_files
# Check if the dataset is all boxes or all segments
len_cls = sum(len(lb['cls']) for lb in labels)
len_boxes = sum(len(lb['bboxes']) for lb in labels)
len_segments = sum(len(lb['segments']) for lb in labels)
lengths = ((len(lb['cls']), len(lb['bboxes']), len(lb['segments'])) for lb in labels)
len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
if len_segments and len_boxes != len_segments:
LOGGER.warning(
f'WARNING ⚠ Box and segment counts should be equal, but got len(segments) = {len_segments}, '

@ -294,7 +294,7 @@ class Exporter:
# YOLOv8 ONNX export
requirements = ['onnx>=1.12.0']
if self.args.simplify:
requirements += ['onnxsim', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
requirements += ['onnxsim>=0.4.17', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime']
check_requirements(requirements)
import onnx # noqa
@ -513,8 +513,8 @@ class Exporter:
cuda = torch.cuda.is_available()
check_requirements(f"tensorflow{'-macos' if MACOS else '-aarch64' if ARM64 else '' if cuda else '-cpu'}")
import tensorflow as tf # noqa
check_requirements(('onnx', 'onnx2tf', 'sng4onnx', 'onnxsim', 'onnx_graphsurgeon', 'tflite_support',
'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
check_requirements(('onnx', 'onnx2tf>=1.7.7', 'sng4onnx>=1.0.1', 'onnxsim>=0.4.17', 'onnx_graphsurgeon>=0.3.26',
'tflite_support', 'onnxruntime-gpu' if torch.cuda.is_available() else 'onnxruntime'),
cmds='--extra-index-url https://pypi.ngc.nvidia.com')
LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
@ -529,7 +529,7 @@ class Exporter:
# Export to TF
int8 = '-oiqt -qt per-tensor' if self.args.int8 else ''
cmd = f'onnx2tf -i {f_onnx} -o {f} --non_verbose {int8}'
cmd = f'onnx2tf -i {f_onnx} -o {f} -nuo --non_verbose {int8}'
LOGGER.info(f'\n{prefix} running {cmd}')
subprocess.run(cmd, shell=True)
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml

@ -9,8 +9,9 @@ from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, Segmentat
guess_model_task, nn)
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.engine.exporter import Exporter
from ultralytics.yolo.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, callbacks, yaml_load
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_yaml
from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
is_git_dir, is_pip_package, yaml_load)
from ultralytics.yolo.utils.checks import check_file, check_imgsz, check_pip_update, check_yaml
from ultralytics.yolo.utils.downloads import GITHUB_ASSET_STEMS
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
@ -150,6 +151,13 @@ class YOLO:
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
def _check_pip_update(self):
"""
Inform user of ultralytics package update availability
"""
if is_pip_package():
check_pip_update()
def reset(self):
"""
Resets the model modules.
@ -189,6 +197,10 @@ class YOLO:
Returns:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
if source is None:
source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
LOGGER.warning(f"WARNING ⚠ 'source' is missing. Using 'source={source}'.")
overrides = self.overrides.copy()
overrides['conf'] = 0.25
overrides.update(kwargs) # prefer kwargs
@ -251,11 +263,12 @@ class YOLO:
Args:
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
from ultralytics.yolo.utils.benchmarks import run_benchmarks
self._check_is_pytorch_model()
from ultralytics.yolo.utils.benchmarks import benchmark
overrides = self.model.args.copy()
overrides.update(kwargs)
overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
return run_benchmarks(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
return benchmark(model=self, imgsz=overrides['imgsz'], half=overrides['half'], device=overrides['device'])
def export(self, **kwargs):
"""
@ -283,6 +296,7 @@ class YOLO:
**kwargs (Any): Any number of arguments representing the training configuration.
"""
self._check_is_pytorch_model()
self._check_pip_update()
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get('cfg'):

@ -178,7 +178,12 @@ class BasePredictor:
self.run_callbacks('on_predict_postprocess_end')
# visualize, save, write results
for i in range(len(im)):
n = len(im)
for i in range(n):
self.results[i].speed = {
'preprocess': self.dt[0].dt * 1E3 / n,
'inference': self.dt[1].dt * 1E3 / n,
'postprocess': self.dt[2].dt * 1E3 / n}
p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
else (path, im0s.copy())
p = Path(p)

@ -354,22 +354,6 @@ def get_git_branch():
return None # if not git dir or on error
def get_latest_pypi_version(package_name='ultralytics'):
"""
Returns the latest version of a PyPI package without downloading or installing it.
Parameters:
package_name (str): The name of the package to find the latest version for.
Returns:
str: The latest version of the package.
"""
response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
if response.status_code == 200:
return response.json()['info']['version']
return None
def get_default_args(func):
"""Returns a dictionary of default arguments for a function.
@ -611,7 +595,7 @@ def set_settings(kwargs, file=USER_CONFIG_DIR / 'settings.yaml'):
# Run below code on yolo/utils init ------------------------------------------------------------------------------------
# Set logger
set_logging(LOGGING_NAME) # run before defining LOGGER
set_logging(LOGGING_NAME, verbose=VERBOSE) # run before defining LOGGER
LOGGER = logging.getLogger(LOGGING_NAME) # define globally (used in train.py, val.py, detect.py, etc.)
if WINDOWS:
for fn in LOGGER.info, LOGGER.warning:

@ -37,11 +37,7 @@ from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.torch_utils import select_device
def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
imgsz=640,
half=False,
device='cpu',
hard_fail=False):
def benchmark(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt', imgsz=160, half=False, device='cpu', hard_fail=0.30):
device = select_device(device, verbose=False)
if isinstance(model, (str, Path)):
model = YOLO(model)
@ -52,6 +48,7 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
try:
assert i not in (9, 10), 'inference not supported' # Edge TPU and TF.js are unsupported
assert i != 5 or platform.system() == 'Darwin', 'inference only supported on macOS>=10.13' # CoreML
assert i != 11 or model.task != 'classify', 'paddle-classify bug'
if 'cpu' in device.type:
assert cpu, 'inference not supported on CPU'
@ -85,26 +82,28 @@ def run_benchmarks(model=Path(SETTINGS['weights_dir']) / 'yolov8n.pt',
y.append([name, '', round(file_size(filename), 1), round(metric, 4), round(speed, 2)])
except Exception as e:
if hard_fail:
assert type(e) is AssertionError, f'Benchmark --hard-fail for {name}: {e}'
assert type(e) is AssertionError, f'Benchmark hard_fail for {name}: {e}'
LOGGER.warning(f'ERROR ❌ Benchmark failure for {name}: {e}')
y.append([name, '', None, None, None]) # mAP, t_inference
# Print results
LOGGER.info('\n')
check_yolo(device=device) # print system info
c = ['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)'] if map else ['Format', 'Export', '', '']
c = ['Format', 'Status❔', 'Size (MB)', key, 'Inference time (ms/im)']
df = pd.DataFrame(y, columns=c)
LOGGER.info(f'\nBenchmarks complete for {Path(model.ckpt_path).name} on {data} at imgsz={imgsz} '
f'({time.time() - t0:.2f}s)')
LOGGER.info(str(df if map else df.iloc[:, :2]))
if hard_fail and isinstance(hard_fail, str):
name = Path(model.ckpt_path).name
s = f'\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({time.time() - t0:.2f}s)\n{df}\n'
LOGGER.info(s)
with open('benchmarks.log', 'a') as f:
f.write(s)
if hard_fail and isinstance(hard_fail, float):
metrics = df[key].array # values to compare to floor
floor = eval(hard_fail) # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: metric < floor {floor}'
floor = hard_fail # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: one or more metric(s) < floor {floor}'
return df
if __name__ == '__main__':
run_benchmarks()
benchmark()

@ -16,6 +16,7 @@ import cv2
import numpy as np
import pkg_resources as pkg
import psutil
import requests
import torch
from matplotlib import font_manager
@ -117,6 +118,31 @@ def check_version(current: str = '0.0.0',
return result
def check_latest_pypi_version(package_name='ultralytics'):
"""
Returns the latest version of a PyPI package without downloading or installing it.
Parameters:
package_name (str): The name of the package to find the latest version for.
Returns:
str: The latest version of the package.
"""
response = requests.get(f'https://pypi.org/pypi/{package_name}/json')
if response.status_code == 200:
return response.json()['info']['version']
return None
def check_pip_update():
from ultralytics import __version__
latest = check_latest_pypi_version()
latest = '9.0.0'
if pkg.parse_version(__version__) < pkg.parse_version(latest):
LOGGER.info(f'New https://pypi.org/project/ultralytics/{latest} available 😃 '
f"Update with 'pip install -U ultralytics'")
def check_font(font='Arial.ttf'):
"""
Find font locally or download to user's configuration directory if it does not already exist.

@ -1,10 +1,12 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import os
import re
import shutil
import socket
import sys
import tempfile
from pathlib import Path
from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9
@ -22,12 +24,12 @@ def find_free_network_port() -> int:
def generate_ddp_file(trainer):
import_path = '.'.join(str(trainer.__class__).split('.')[1:-1])
module, name = f'{trainer.__class__.__module__}.{trainer.__class__.__name__}'.rsplit('.', 1)
content = f'''cfg = {vars(trainer.args)} \nif __name__ == "__main__":
from ultralytics.{import_path} import {trainer.__class__.__name__}
from {module} import {name}
trainer = {trainer.__class__.__name__}(cfg=cfg)
trainer = {name}(cfg=cfg)
trainer.train()'''
(USER_CONFIG_DIR / 'DDP').mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(prefix='_temp_',
@ -41,12 +43,12 @@ def generate_ddp_file(trainer):
def generate_ddp_command(world_size, trainer):
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
file = os.path.abspath(sys.argv[0])
using_cli = not file.endswith('.py')
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
if using_cli:
file = str(Path(sys.argv[0]).resolve())
safe_pattern = re.compile(r'^[a-zA-Z0-9_. /\\-]{1,128}$') # allowed characters and maximum of 100 characters
if not (safe_pattern.match(file) and Path(file).exists() and file.endswith('.py')): # using CLI
file = generate_ddp_file(trainer)
dist_cmd = 'torch.distributed.run' if TORCH_1_9 else 'torch.distributed.launch'
port = find_free_network_port()

Loading…
Cancel
Save