`ultralytics 8.0.48` Edge TPU fix and Metrics updates (#1171)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: majid nasiri <majnasai@gmail.com>
pull/1038/head^2 v8.0.48
Glenn Jocher 2 years ago committed by GitHub
parent a58f766f94
commit 74e4c94806
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 77
      .github/workflows/ci.yaml
  2. 23
      docs/predict.md
  3. 2
      examples/tutorial.ipynb
  4. 4
      tests/test_cli.py
  5. 4
      tests/test_python.py
  6. 2
      ultralytics/__init__.py
  7. 22
      ultralytics/hub/__init__.py
  8. 109
      ultralytics/hub/session.py
  9. 64
      ultralytics/hub/utils.py
  10. 8
      ultralytics/yolo/cfg/__init__.py
  11. 19
      ultralytics/yolo/engine/exporter.py
  12. 23
      ultralytics/yolo/engine/model.py
  13. 3
      ultralytics/yolo/engine/results.py
  14. 19
      ultralytics/yolo/utils/__init__.py
  15. 1
      ultralytics/yolo/utils/callbacks/base.py
  16. 35
      ultralytics/yolo/utils/callbacks/hub.py
  17. 7
      ultralytics/yolo/utils/callbacks/tensorboard.py
  18. 23
      ultralytics/yolo/utils/checks.py
  19. 9
      ultralytics/yolo/utils/downloads.py
  20. 185
      ultralytics/yolo/utils/metrics.py
  21. 6
      ultralytics/yolo/utils/plotting.py
  22. 3
      ultralytics/yolo/v8/classify/predict.py
  23. 3
      ultralytics/yolo/v8/classify/val.py

@ -12,6 +12,56 @@ on:
- cron: '0 0 * * *' # runs at 00:00 UTC every day
jobs:
HUB:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ['3.10']
model: [yolov5n]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Get cache dir # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
id: pip-cache
run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash # for Windows compatibility
- name: Cache pip
uses: actions/cache@v3
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-pip-
- name: Install requirements
shell: bash # for Windows compatibility
run: |
python -m pip install --upgrade pip wheel
pip install -e . --extra-index-url https://download.pytorch.org/whl/cpu
- name: Check environment
run: |
echo "RUNNER_OS is ${{ runner.os }}"
echo "GITHUB_EVENT_NAME is ${{ github.event_name }}"
echo "GITHUB_WORKFLOW is ${{ github.workflow }}"
echo "GITHUB_ACTOR is ${{ github.actor }}"
echo "GITHUB_REPOSITORY is ${{ github.repository }}"
echo "GITHUB_REPOSITORY_OWNER is ${{ github.repository_owner }}"
python --version
pip --version
pip list
- name: Test HUB training
shell: python
env:
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
run: |
import os
from ultralytics import hub
key = os.environ['APIKEY']
hub.reset_model(key)
hub.start(key)
Benchmarks:
runs-on: ${{ matrix.os }}
strategy:
@ -25,12 +75,16 @@ jobs:
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
#- name: Cache pip
# uses: actions/cache@v3
# with:
# path: ~/.cache/pip
# key: ${{ runner.os }}-Benchmarks-${{ hashFiles('requirements.txt') }}
# restore-keys: ${{ runner.os }}-Benchmarks-
- name: Get cache dir # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
id: pip-cache
run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
shell: bash # for Windows compatibility
- name: Cache pip
uses: actions/cache@v3
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-pip-
- name: Install requirements
shell: bash # for Windows compatibility
run: |
@ -120,17 +174,6 @@ jobs:
python --version
pip --version
pip list
- name: Test pip package
shell: python
env:
APIKEY: ${{ secrets.ULTRALYTICS_HUB_APIKEY }}
run: |
import os
import ultralytics
key = os.environ['APIKEY']
ultralytics.checks()
# ultralytics.reset_model(key) # reset trained model
# ultralytics.start(key) # train model
- name: Test detection
shell: bash # for Windows compatibility
run: |

@ -28,6 +28,29 @@ predictor's call method.
probs = r.probs # Class probabilities for classification outputs
```
## Sources
YOLOv8 can run inference on a variety of sources. The table below lists the various sources that can be used as input
for YOLOv8, along with the required format and notes. Sources include images, URLs, PIL images, OpenCV, numpy arrays,
torch tensors, CSV files, videos, directories, globs, YouTube videos, and streams. The table also indicates whether each
source can be used as a stream and the model argument required for that source.
| source | stream | model(arg) | type | notes |
|------------|---------|--------------------------------------------|----------------|------------------|
| image | | `'im.jpg'` | `str`, `Path` | |
| URL | | `'https://ultralytics.com/images/bus.jpg'` | `str` | |
| screenshot | | `'screen'` | `str` | |
| PIL | | `Image.open('im.jpg')` | `PIL.Image` | HWC, RGB |
| OpenCV | | `cv2.imread('im.jpg')[:,:,::-1]` | `np.ndarray` | HWC, BGR to RGB |
| numpy | | `np.zeros((640,1280,3))` | `np.ndarray` | HWC |
| torch | | `torch.zeros(16,3,320,640)` | `torch.Tensor` | BCHW, RGB |
| CSV | | `'sources.csv'` | `str`, `Path` | RTSP, RTMP, HTTP |
| video | &check; | `'vid.mp4'` | `str`, `Path` | |
| directory | &check; | `'path/'` | `str`, `Path` | |
| glob | &check; | `path/*.jpg'` | `str` | Use `*` operator |
| YouTube | &check; | `'https://youtu.be/Zgi9g1ksQHc'` | `str` | |
| stream | &check; | `'rtsp://example.com/media.mp4'` | `str` | RTSP, RTMP, HTTP |
## Working with Results
Results object consists of these component objects:

@ -645,7 +645,7 @@
"cell_type": "code",
"source": [
"# Git clone install (for development)\n",
"!git clone https://github.com/ultralytics/ultralytics\n",
"!git clone https://github.com/ultralytics/ultralytics -b main\n",
"%pip install -qe ultralytics"
],
"metadata": {

@ -3,7 +3,7 @@
import subprocess
from pathlib import Path
from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks
from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n'
CFG = 'yolov8n'
@ -49,7 +49,7 @@ def test_val_classify():
# Predict checks -------------------------------------------------------------------------------------------------------
def test_predict_detect():
run(f"yolo predict model={MODEL}.pt source={ROOT / 'assets'} imgsz=32 save save_crop save_txt")
if checks.check_online():
if ONLINE:
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=32')
run(f'yolo predict model={MODEL}.pt source=https://ultralytics.com/assets/decelera_portrait_min.mov imgsz=32')

@ -9,7 +9,7 @@ from PIL import Image
from ultralytics import YOLO
from ultralytics.yolo.data.build import load_inference_source
from ultralytics.yolo.utils import LINUX, ROOT, SETTINGS, checks
from ultralytics.yolo.utils import LINUX, ONLINE, ROOT, SETTINGS
MODEL = Path(SETTINGS['weights_dir']) / 'yolov8n.pt'
CFG = 'yolov8n.yaml'
@ -58,7 +58,7 @@ def test_predict_img():
batch = [
str(SOURCE), # filename
Path(SOURCE), # Path
'https://ultralytics.com/images/zidane.jpg' if checks.check_online() else SOURCE, # URI
'https://ultralytics.com/images/zidane.jpg' if ONLINE else SOURCE, # URI
cv2.imread(str(SOURCE)), # OpenCV
Image.open(SOURCE), # PIL
np.zeros((320, 640, 3))] # numpy

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
__version__ = '8.0.47'
__version__ = '8.0.48'
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils.checks import check_yolo as checks

@ -3,11 +3,11 @@
import requests
from ultralytics.hub.auth import Auth
from ultralytics.hub.session import HubTrainingSession
from ultralytics.hub.utils import split_key
from ultralytics.hub.session import HUBTrainingSession
from ultralytics.hub.utils import PREFIX, split_key
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_LIST
from ultralytics.yolo.engine.model import YOLO
from ultralytics.yolo.utils import LOGGER, PREFIX, emojis
from ultralytics.yolo.utils import LOGGER, emojis
# Define all export formats
EXPORT_FORMATS_HUB = EXPORT_FORMATS_LIST + ['ultralytics_tflite', 'ultralytics_coreml']
@ -18,7 +18,6 @@ def start(key=''):
Start training models with Ultralytics HUB. Usage: from ultralytics.hub import start; start('API_KEY')
"""
auth = Auth(key)
try:
if not auth.get_state():
model_id = request_api_key(auth)
else:
@ -27,14 +26,11 @@ def start(key=''):
if not model_id:
raise ConnectionError(emojis('Connecting with global API key is not currently supported. ❌'))
session = HubTrainingSession(model_id=model_id, auth=auth)
session = HUBTrainingSession(model_id=model_id, auth=auth)
session.check_disk_space()
model = YOLO(session.input_file)
session.register_callbacks(model)
model = YOLO(model=session.model_file, session=session)
model.train(**session.train_args)
except Exception as e:
LOGGER.warning(f'{PREFIX}{e}')
def request_api_key(auth, max_attempts=3):
@ -62,9 +58,9 @@ def reset_model(key=''):
r = requests.post('https://api.ultralytics.com/model-reset', json={'apiKey': api_key, 'modelId': model_id})
if r.status_code == 200:
LOGGER.info(f'{PREFIX}model reset successfully')
LOGGER.info(f'{PREFIX}Model reset successfully')
return
LOGGER.warning(f'{PREFIX}model reset failure {r.status_code} {r.reason}')
LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
def export_model(key='', format='torchscript'):
@ -76,7 +72,7 @@ def export_model(key='', format='torchscript'):
'apiKey': api_key,
'modelId': model_id,
'format': format})
assert (r.status_code == 200), f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
LOGGER.info(f'{PREFIX}{format} export started ✅')
@ -89,7 +85,7 @@ def get_export(key='', format='torchscript'):
'apiKey': api_key,
'modelId': model_id,
'format': format})
assert (r.status_code == 200), f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
return r.json()

@ -1,30 +1,27 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
import json
import signal
import sys
from pathlib import Path
from time import sleep, time
from time import sleep
import requests
from ultralytics.hub.utils import HUB_API_ROOT, check_dataset_disk_space, smart_request
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, emojis, is_colab, threaded
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
from ultralytics.yolo.utils import LOGGER, PREFIX, __version__, checks, emojis, is_colab, threaded
AGENT_NAME = f'python-{__version__}-colab' if is_colab() else f'python-{__version__}-local'
session = None
class HubTrainingSession:
class HUBTrainingSession:
def __init__(self, model_id, auth):
self.agent_id = None # identifies which instance is communicating with server
self.model_id = model_id
self.api_url = f'{HUB_API_ROOT}/v1/models/{model_id}'
self.auth_header = auth.get_auth_header()
self._rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
self._timers = {} # rate limit timers (seconds)
self._metrics_queue = {} # metrics queue
self.rate_limits = {'metrics': 3.0, 'ckpt': 900.0, 'heartbeat': 300.0} # rate limits (seconds)
self.timers = {} # rate limit timers (seconds)
self.metrics_queue = {} # metrics queue
self.model = self._get_model()
self.alive = True
self._start_heartbeat() # start heartbeats
@ -50,16 +47,15 @@ class HubTrainingSession:
self.alive = False
def upload_metrics(self):
payload = {'metrics': self._metrics_queue.copy(), 'type': 'metrics'}
smart_request(f'{self.api_url}', json=payload, headers=self.auth_header, code=2)
payload = {'metrics': self.metrics_queue.copy(), 'type': 'metrics'}
smart_request('post', self.api_url, json=payload, headers=self.auth_header, code=2)
def _get_model(self):
# Returns model from database by id
api_url = f'{HUB_API_ROOT}/v1/models/{self.model_id}'
headers = self.auth_header
try:
response = smart_request(api_url, method='get', headers=headers, thread=False, code=0)
response = smart_request('get', api_url, headers=self.auth_header, thread=False, code=0)
data = response.json().get('data', None)
if data.get('status', None) == 'trained':
@ -82,11 +78,8 @@ class HubTrainingSession:
'cache': data['cache'],
'data': data['data']}
self.input_file = data.get('cfg', data['weights'])
# hack for yolov5 cfg adds u
if 'cfg' in data and 'yolov5' in data['cfg']:
self.input_file = data['cfg'].replace('.yaml', 'u.yaml')
self.model_file = data.get('cfg', data['weights'])
self.model_file = checks.check_yolov5u_filename(self.model_file, verbose=False) # YOLOv5->YOLOv5u
return data
except requests.exceptions.ConnectionError as e:
@ -98,86 +91,44 @@ class HubTrainingSession:
if not check_dataset_disk_space(self.model['data']):
raise MemoryError('Not enough disk space')
def register_callbacks(self, trainer):
trainer.add_callback('on_pretrain_routine_end', self.on_pretrain_routine_end)
trainer.add_callback('on_fit_epoch_end', self.on_fit_epoch_end)
trainer.add_callback('on_model_save', self.on_model_save)
trainer.add_callback('on_train_end', self.on_train_end)
def on_pretrain_routine_end(self, trainer):
"""
Start timer for upload rate limit.
This method does not use trainer. It is passed to all callbacks by default.
"""
# Start timer for upload rate limit
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
self._timers = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
def on_fit_epoch_end(self, trainer):
# Upload metrics after val end
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
if trainer.epoch == 0:
model_info = {
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
all_plots = {**all_plots, **model_info}
self._metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - self._timers['metrics'] > self._rate_limits['metrics']:
self.upload_metrics()
self._timers['metrics'] = time() # reset timer
self._metrics_queue = {} # reset queue
def on_model_save(self, trainer):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - self._timers['ckpt'] > self._rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {self.model_id}')
self._upload_model(trainer.epoch, trainer.last, is_best)
self._timers['ckpt'] = time() # reset timer
def on_train_end(self, trainer):
# Upload final model and metrics with exponential standoff
LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
f'{PREFIX}Uploading final {self.model_id}')
self._upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
self.alive = False # stop heartbeats
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{self.model_id} 🚀')
def _upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
def upload_model(self, epoch, weights, is_best=False, map=0.0, final=False):
# Upload a model to HUB
if Path(weights).is_file():
with open(weights, 'rb') as f:
file = f.read()
else:
LOGGER.warning(f'{PREFIX}WARNING ⚠ Model upload failed. Missing model {weights}.')
LOGGER.warning(f'{PREFIX}WARNING ⚠ Model upload issue. Missing model {weights}.')
file = None
url = f'{self.api_url}/upload'
# url = 'http://httpbin.org/post' # for debug
data = {'epoch': epoch}
if final:
data.update({'type': 'final', 'map': map})
else:
data.update({'type': 'epoch', 'isBest': bool(is_best)})
smart_request(f'{self.api_url}/upload',
smart_request('post',
url,
data=data,
files={'best.pt' if final else 'last.pt': file},
files={'best.pt': file},
headers=self.auth_header,
retry=10 if final else None,
timeout=3600 if final else None,
code=4 if final else 3)
retry=10,
timeout=3600,
thread=False,
progress=True,
code=4)
else:
data.update({'type': 'epoch', 'isBest': bool(is_best)})
smart_request('post', url, data=data, files={'last.pt': file}, headers=self.auth_header, code=3)
@threaded
def _start_heartbeat(self):
while self.alive:
r = smart_request(f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
r = smart_request('post',
f'{HUB_API_ROOT}/v1/agent/heartbeat/models/{self.model_id}',
json={
'agent': AGENT_NAME,
'agentId': self.agent_id},
headers=self.auth_header,
retry=0,
code=5,
thread=False)
thread=False) # already in a thread
self.agent_id = r.json().get('data', {}).get('agentId', None)
sleep(self._rate_limits['heartbeat'])
sleep(self.rate_limits['heartbeat'])

@ -10,13 +10,13 @@ from pathlib import Path
from random import random
import requests
from tqdm import tqdm
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, ENVIRONMENT, LOGGER, RANK, SETTINGS, TESTS_RUNNING, TryExcept,
__version__, colorstr, emojis, get_git_origin_url, is_colab, is_git_dir,
is_pip_package)
from ultralytics.yolo.utils.checks import check_online
from ultralytics.yolo.utils import (DEFAULT_CFG_DICT, ENVIRONMENT, LOGGER, ONLINE, RANK, SETTINGS, TESTS_RUNNING,
TQDM_BAR_FORMAT, TryExcept, __version__, colorstr, emojis, get_git_origin_url,
is_colab, is_git_dir, is_pip_package)
PREFIX = colorstr('Ultralytics: ')
PREFIX = colorstr('Ultralytics HUB: ')
HELP_MSG = 'If this issue persists please visit https://github.com/ultralytics/hub/issues for assistance.'
HUB_API_ROOT = os.environ.get('ULTRALYTICS_HUB_API', 'https://api.ultralytics.com')
@ -60,7 +60,6 @@ def request_with_credentials(url: str) -> any:
return output.eval_js('_hub_tmp')
# Deprecated TODO: eliminate this function?
def split_key(key=''):
"""
Verify and split a 'api_key[sep]model_id' string, sep is one of '.' or '_'
@ -84,36 +83,61 @@ def split_key(key=''):
return api_key, model_id
def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method='post', verbose=True, **kwargs):
def requests_with_progress(method, url, **kwargs):
"""
Make an HTTP request using the specified method and URL, with an optional progress bar.
Args:
method (str): The HTTP method to use (e.g. 'GET', 'POST').
url (str): The URL to send the request to.
progress (bool, optional): Whether to display a progress bar. Defaults to False.
**kwargs: Additional keyword arguments to pass to the underlying `requests.request` function.
Returns:
requests.Response: The response from the HTTP request.
"""
progress = kwargs.pop('progress', False)
if not progress:
return requests.request(method, url, **kwargs)
response = requests.request(method, url, stream=True, **kwargs)
total = int(response.headers.get('content-length', 0)) # total size
pbar = tqdm(total=total, unit='B', unit_scale=True, unit_divisor=1024, bar_format=TQDM_BAR_FORMAT)
for data in response.iter_content(chunk_size=1024):
pbar.update(len(data))
pbar.close()
return response
def smart_request(method, url, retry=3, timeout=30, thread=True, code=-1, verbose=True, progress=False, **kwargs):
"""
Makes an HTTP request using the 'requests' library, with exponential backoff retries up to a specified timeout.
Args:
*args: Positional arguments to be passed to the requests function specified in method.
method (str): The HTTP method to use for the request. Choices are 'post' and 'get'.
url (str): The URL to make the request to.
retry (int, optional): Number of retries to attempt before giving up. Default is 3.
timeout (int, optional): Timeout in seconds after which the function will give up retrying. Default is 30.
thread (bool, optional): Whether to execute the request in a separate daemon thread. Default is True.
code (int, optional): An identifier for the request, used for logging purposes. Default is -1.
method (str, optional): The HTTP method to use for the request. Choices are 'post' and 'get'. Default is 'post'.
verbose (bool, optional): A flag to determine whether to print out to console or not. Default is True.
progress (bool, optional): Whether to show a progress bar during the request. Default is False.
**kwargs: Keyword arguments to be passed to the requests function specified in method.
Returns:
requests.Response: The HTTP response object. If the request is executed in a separate thread, returns None.
"""
retry_codes = (408, 500) # retry only these codes
@TryExcept(verbose=verbose)
def func(*func_args, **func_kwargs):
def func(func_method, func_url, **func_kwargs):
r = None # response
t0 = time.time() # initial time for timer
for i in range(retry + 1):
if (time.time() - t0) > timeout:
break
if method == 'post':
r = requests.post(*func_args, **func_kwargs) # i.e. post(url, data, json, files)
elif method == 'get':
r = requests.get(*func_args, **func_kwargs) # i.e. get(url, data, json, files)
r = requests_with_progress(func_method, func_url, **func_kwargs) # i.e. get(url, data, json, files)
if r.status_code == 200:
break
try:
@ -134,6 +158,8 @@ def smart_request(*args, retry=3, timeout=30, thread=True, code=-1, method='post
time.sleep(2 ** i) # exponential standoff
return r
args = method, url
kwargs['progress'] = progress
if thread:
threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
else:
@ -157,8 +183,8 @@ class Traces:
self.enabled = \
SETTINGS['sync'] and \
RANK in {-1, 0} and \
check_online() and \
not TESTS_RUNNING and \
ONLINE and \
(is_pip_package() or get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git')
def __call__(self, cfg, all_keys=False, traces_sample_rate=1.0):
@ -182,13 +208,7 @@ class Traces:
trace = {'uuid': SETTINGS['uuid'], 'cfg': cfg, 'metadata': self.metadata}
# Send a request to the HUB API to sync analytics
smart_request(f'{HUB_API_ROOT}/v1/usage/anonymous',
json=trace,
headers=None,
code=3,
retry=0,
timeout=1.0,
verbose=False)
smart_request('post', f'{HUB_API_ROOT}/v1/usage/anonymous', json=trace, code=3, retry=0, verbose=False)
# Run below code on hub/utils init -------------------------------------------------------------------------------------

@ -13,7 +13,7 @@ from ultralytics.yolo.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_P
CLI_HELP_MSG = \
f"""
Arguments received: {str(['yolo'] + sys.argv[1:])}. Note that Ultralytics 'yolo' commands use the following syntax:
Arguments received: {str(['yolo'] + sys.argv[1:])}. Ultralytics 'yolo' commands use the following syntax:
yolo TASK MODE ARGS
@ -217,6 +217,9 @@ def entrypoint(debug=''):
if a.startswith('--'):
LOGGER.warning(f"WARNING ⚠ '{a}' does not require leading dashes '--', updating to '{a[2:]}'.")
a = a[2:]
if a.endswith(','):
LOGGER.warning(f"WARNING ⚠ '{a}' does not require trailing comma ',', updating to '{a[:-1]}'.")
a = a[:-1]
if '=' in a:
try:
re.sub(r' *= *', '=', a) # remove spaces around equals sign
@ -284,6 +287,9 @@ def entrypoint(debug=''):
model = YOLO(model, task=task)
# Task Update
if task and task != model.task:
LOGGER.warning(f"WARNING ⚠ conflicting 'task={task}' passed with 'task={model.task}' model. "
f'This may produce errors.')
task = task or model.task
overrides['task'] = task

@ -243,15 +243,12 @@ class Exporter:
if coreml: # CoreML
f[4], _ = self._export_coreml()
if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
LOGGER.warning('WARNING ⚠ YOLOv8 TensorFlow export is still under development. '
'Please consider contributing to the effort if you have TF expertise. Thank you!')
nms = False
self.args.int8 |= edgetpu
f[5], s_model = self._export_saved_model()
if pb or tfjs: # pb prerequisite to tfjs
f[6], _ = self._export_pb(s_model)
if tflite:
f[7], _ = self._export_tflite(s_model, nms=nms, agnostic_nms=self.args.agnostic_nms)
f[7], _ = self._export_tflite(s_model, nms=False, agnostic_nms=self.args.agnostic_nms)
if edgetpu:
f[8], _ = self._export_edgetpu(tflite_model=str(
Path(f[5]) / (self.file.stem + '_full_integer_quant.tflite'))) # int8 in/out
@ -619,20 +616,18 @@ class Exporter:
@try_export
def _export_edgetpu(self, tflite_model='', prefix=colorstr('Edge TPU:')):
# YOLOv8 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
LOGGER.warning(f'{prefix} WARNING ⚠ Edge TPU known bug https://github.com/ultralytics/ultralytics/issues/1185')
cmd = 'edgetpu_compiler --version'
help_url = 'https://coral.ai/docs/edgetpu/compiler/'
assert LINUX, f'export only supported on Linux. See {help_url}'
if subprocess.run(f'{cmd} > /dev/null', shell=True).returncode != 0:
if subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True).returncode != 0:
LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
for c in (
# 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', # errors
'wget --no-check-certificate -q -O - https://packages.cloud.google.com/apt/doc/apt-key.gpg | '
'sudo apt-key add -',
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | ' # no comma
'sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update',
'sudo apt-get install edgetpu-compiler'):
'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]

@ -43,7 +43,7 @@ class YOLO:
cfg (str): The model configuration if loaded from *.yaml file.
ckpt_path (str): The checkpoint file path.
overrides (dict): Overrides for the trainer object.
metrics_data (Any): The data for metrics.
metrics (Any): The data for metrics.
Methods:
__call__(source=None, stream=False, **kwargs):
@ -67,7 +67,7 @@ class YOLO:
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
def __init__(self, model='yolov8n.pt', task=None) -> None:
def __init__(self, model='yolov8n.pt', task=None, session=None) -> None:
"""
Initializes the YOLO model.
@ -83,7 +83,8 @@ class YOLO:
self.cfg = None # if loaded from *.yaml
self.ckpt_path = None
self.overrides = {} # overrides for trainer object
self.metrics_data = None
self.metrics = None # validation/training metrics
self.session = session # HUB session
# Load or create new YOLO model
suffix = Path(model).suffix
@ -184,6 +185,7 @@ class YOLO:
self._check_is_pytorch_model()
self.model.fuse()
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
@ -217,7 +219,6 @@ class YOLO:
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
@smart_inference_mode()
def track(self, source=None, stream=False, **kwargs):
from ultralytics.tracker import register_tracker
register_tracker(self)
@ -252,7 +253,7 @@ class YOLO:
validator = TASK_MAP[self.task][2](args=args)
validator(model=self.model)
self.metrics_data = validator.metrics
self.metrics = validator.metrics
return validator.metrics
@ -314,12 +315,13 @@ class YOLO:
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
# update model and cfg after training
if RANK in {0, -1}:
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics_data = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def to(self, device):
"""
@ -352,15 +354,6 @@ class YOLO:
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@property
def metrics(self):
"""
Returns metrics if computed
"""
if not self.metrics_data:
LOGGER.info('No metrics data found! Run training or validation operation first.')
return self.metrics_data
@staticmethod
def add_callback(event: str, func):
"""

@ -139,7 +139,8 @@ class Results:
annotator.masks(masks.data, colors=[colors(x, True) for x in boxes.cls], im_gpu=im)
if logits is not None:
top5i = logits.argsort(0, descending=True)[:5].tolist() # top 5 indices
n5 = min(len(self.names), 5)
top5i = logits.argsort(0, descending=True)[:n5].tolist() # top 5 indices
text = f"{', '.join(f'{names[j] if names else j} {logits[j]:.2f}' for j in top5i)}, "
annotator.text((32, 32), text, txt_color=(255, 255, 255)) # TODO: allow setting colors

@ -243,6 +243,24 @@ def is_docker() -> bool:
return False
def is_online() -> bool:
"""
Check internet connectivity by attempting to connect to a known online host.
Returns:
bool: True if connection is successful, False otherwise.
"""
import socket
with contextlib.suppress(Exception):
host = socket.gethostbyname('www.github.com')
socket.create_connection((host, 80), timeout=2)
return True
return False
ONLINE = is_online()
def is_pip_package(filepath: str = __name__) -> bool:
"""
Determines if the file at the given filepath is part of a pip package.
@ -513,6 +531,7 @@ def set_sentry():
RANK in {-1, 0} and \
Path(sys.argv[0]).name == 'yolo' and \
not TESTS_RUNNING and \
ONLINE and \
((is_pip_package() and not is_git_dir()) or
(get_git_origin_url() == 'https://github.com/ultralytics/ultralytics.git' and get_git_branch() == 'main')):

@ -151,4 +151,5 @@ def add_integration_callbacks(instance):
for x in clearml_callbacks, comet_callbacks, hub_callbacks, tb_callbacks:
for k, v in x.items():
if v not in instance.callbacks[k]: # prevent duplicate callbacks addition
instance.callbacks[k].append(v) # callback[name].append(func)

@ -4,24 +4,33 @@ import json
from time import time
from ultralytics.hub.utils import PREFIX, traces
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
from ultralytics.yolo.utils import LOGGER
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params
def on_pretrain_routine_end(trainer):
session = not TESTS_RUNNING and getattr(trainer, 'hub_session', None)
session = getattr(trainer, 'hub_session', None)
if session:
# Start timer for upload rate limit
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
session.t = {'metrics': time(), 'ckpt': time()} # start timer on self.rate_limit
session.timers = {'metrics': time(), 'ckpt': time()} # start timer on session.rate_limit
def on_fit_epoch_end(trainer):
session = getattr(trainer, 'hub_session', None)
if session:
session.metrics_queue[trainer.epoch] = json.dumps(trainer.metrics) # json string
if time() - session.t['metrics'] > session.rate_limits['metrics']:
# Upload metrics after val end
all_plots = {**trainer.label_loss_items(trainer.tloss, prefix='train'), **trainer.metrics}
if trainer.epoch == 0:
model_info = {
'model/parameters': get_num_params(trainer.model),
'model/GFLOPs': round(get_flops(trainer.model), 3),
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)}
all_plots = {**all_plots, **model_info}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
if time() - session.timers['metrics'] > session.rate_limits['metrics']:
session.upload_metrics()
session.t['metrics'] = time() # reset timer
session.timers['metrics'] = time() # reset timer
session.metrics_queue = {} # reset queue
@ -30,21 +39,21 @@ def on_model_save(trainer):
if session:
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.t['ckpt'] > session.rate_limits['ckpt']:
if time() - session.timers['ckpt'] > session.rate_limits['ckpt']:
LOGGER.info(f'{PREFIX}Uploading checkpoint {session.model_id}')
session.upload_model(trainer.epoch, trainer.last, is_best)
session.t['ckpt'] = time() # reset timer
session.timers['ckpt'] = time() # reset timer
def on_train_end(trainer):
session = getattr(trainer, 'hub_session', None)
if session:
# Upload final model and metrics with exponential standoff
LOGGER.info(f'{PREFIX}Training completed successfully ✅\n'
f'{PREFIX}Uploading final {session.model_id}')
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics['metrics/mAP50-95(B)'], final=True)
session.shutdown() # stop heartbeats
LOGGER.info(f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
LOGGER.info(f'{PREFIX}Syncing final model...')
session.upload_model(trainer.epoch, trainer.best, map=trainer.metrics.get('metrics/mAP50-95(B)', 0), final=True)
session.alive = False # stop heartbeats
LOGGER.info(f'{PREFIX}Done ✅\n'
f'{PREFIX}View model at https://hub.ultralytics.com/models/{session.model_id} 🚀')
def on_train_start(trainer):

@ -1,8 +1,12 @@
# Ultralytics YOLO 🚀, GPL-3.0 license
from ultralytics.yolo.utils import LOGGER, TESTS_RUNNING
try:
from torch.utils.tensorboard import SummaryWriter
from ultralytics.yolo.utils import LOGGER
assert not TESTS_RUNNING # do not log pytest
except (ImportError, AssertionError):
SummaryWriter = None
writer = None # TensorBoard SummaryWriter instance
@ -18,7 +22,6 @@ def on_pretrain_routine_start(trainer):
try:
writer = SummaryWriter(str(trainer.save_dir))
except Exception as e:
writer = None # TensorBoard SummaryWriter instance
LOGGER.warning(f'WARNING ⚠ TensorBoard not initialized correctly, not logging this run. {e}')

@ -21,7 +21,7 @@ import torch
from matplotlib import font_manager
from ultralytics.yolo.utils import (AUTOINSTALL, LOGGER, ROOT, USER_CONFIG_DIR, TryExcept, colorstr, downloads, emojis,
is_colab, is_docker, is_jupyter)
is_colab, is_docker, is_jupyter, is_online)
def is_ascii(s) -> bool:
@ -171,21 +171,6 @@ def check_font(font='Arial.ttf'):
return file
def check_online() -> bool:
"""
Check internet connectivity by attempting to connect to a known online host.
Returns:
bool: True if connection is successful, False otherwise.
"""
import socket
with contextlib.suppress(Exception):
host = socket.gethostbyname('www.github.com')
socket.create_connection((host, 80), timeout=2)
return True
return False
def check_python(minimum: str = '3.7.0') -> bool:
"""
Check current python version against the required minimum version.
@ -229,7 +214,7 @@ def check_requirements(requirements=ROOT.parent / 'requirements.txt', exclude=()
if s and install and AUTOINSTALL: # check environment variable
LOGGER.info(f"{prefix} YOLOv8 requirement{'s' * (n > 1)} {s}not found, attempting AutoUpdate...")
try:
assert check_online(), 'AutoUpdate skipped (offline)'
assert is_online(), 'AutoUpdate skipped (offline)'
LOGGER.info(subprocess.check_output(f'pip install {s} {cmds}', shell=True).decode())
s = f"{prefix} {n} package{'s' * (n > 1)} updated per {file or requirements}\n" \
f"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
@ -249,13 +234,13 @@ def check_suffix(file='yolov8n.pt', suffix='.pt', msg=''):
assert s in suffix, f'{msg}{f} acceptable suffix is {suffix}'
def check_yolov5u_filename(file: str):
def check_yolov5u_filename(file: str, verbose: bool = True):
# Replace legacy YOLOv5 filenames with updated YOLOv5u filenames
if 'yolov3' in file or 'yolov5' in file and 'u' not in file:
original_file = file
file = re.sub(r'(.*yolov5([nsmlx]))\.', '\\1u.', file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r'(.*yolov3(|-tiny|-spp))\.', '\\1u.', file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file:
if file != original_file and verbose:
LOGGER.info(f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
f'trained with https://github.com/ultralytics/ultralytics and feature improved performance vs '
f'standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n')

@ -12,7 +12,7 @@ import requests
import torch
from tqdm import tqdm
from ultralytics.yolo.utils import LOGGER, checks
from ultralytics.yolo.utils import LOGGER, checks, is_online
GITHUB_ASSET_NAMES = [f'yolov8{size}{suffix}.pt' for size in 'nsmlx' for suffix in ('', '6', '-cls', '-seg')] + \
[f'yolov5{size}u.pt' for size in 'nsmlx'] + \
@ -112,7 +112,7 @@ def safe_download(url,
break # success
f.unlink() # remove partial downloads
except Exception as e:
if i == 0 and not checks.check_online():
if i == 0 and not is_online():
raise ConnectionError(f'❌ Download failure for {url}. Environment is not online.') from e
elif i >= retry:
raise ConnectionError(f'❌ Download failure for {url}. Retry limit reached.') from e
@ -134,8 +134,7 @@ def safe_download(url,
def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
# Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
from ultralytics.yolo.utils import SETTINGS
from ultralytics.yolo.utils.checks import check_yolov5u_filename
from ultralytics.yolo.utils import SETTINGS # scoped for circular import
def github_assets(repository, version='latest'):
# Return GitHub repo tag and assets (i.e. ['yolov8n.pt', 'yolov8s.pt', ...])
@ -146,7 +145,7 @@ def attempt_download_asset(file, repo='ultralytics/assets', release='v0.0.0'):
# YOLOv3/5u updates
file = str(file)
file = check_yolov5u_filename(file)
file = checks.check_yolov5u_filename(file)
file = Path(file.strip().replace("'", ''))
if file.exists():
return str(file)

@ -43,16 +43,18 @@ def bbox_ioa(box1, box2, eps=1e-7):
def box_iou(box1, box2, eps=1e-7):
# https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
Arguments:
box1 (Tensor[N, 4])
box2 (Tensor[M, 4])
eps
Returns:
iou (Tensor[N, M]): the NxM matrix containing the pairwise
IoU values for every element in boxes1 and boxes2
iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
"""
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
@ -109,7 +111,7 @@ def mask_iou(mask1, mask2, eps=1e-7):
mask1: [N, n] m1 means number of predicted objects
mask2: [M, n] m2 means number of gt objects
Note: n means image_w x image_h
return: masks iou, [N, M]
Returns: masks iou, [N, M]
"""
intersection = torch.matmul(mask1, mask2.t()).clamp(0)
union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
@ -121,7 +123,7 @@ def masks_iou(mask1, mask2, eps=1e-7):
mask1: [N, n] m1 means number of predicted objects
mask2: [N, n] m2 means number of gt objects
Note: n means image_w x image_h
return: masks iou, (N, )
Returns: masks iou, (N, )
"""
intersection = (mask1 * mask2).sum(1).clamp(0) # (N, )
union = (mask1.sum(1) + mask2.sum(1))[None] - intersection # (area1 + area2) - intersection
@ -317,10 +319,10 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
def compute_ap(recall, precision):
""" Compute the average precision, given the recall and precision curves
# Arguments
Arguments:
recall: The recall curve (list)
precision: The precision curve (list)
# Returns
Returns:
Average precision, precision curve, recall curve
"""
@ -344,17 +346,30 @@ def compute_ap(recall, precision):
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), names=(), eps=1e-16, prefix=''):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
tp: True positives (nparray, nx1 or nx10).
conf: Objectness value from 0-1 (nparray).
pred_cls: Predicted object classes (nparray).
target_cls: True object classes (nparray).
plot: Plot precision-recall curve at mAP@0.5
save_dir: Plot save directory
# Returns
The average precision as computed in py-faster-rcnn.
"""
Computes the average precision per class for object detection evaluation.
Args:
tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
conf (np.ndarray): Array of confidence scores of the detections.
pred_cls (np.ndarray): Array of predicted classes of the detections.
target_cls (np.ndarray): Array of true classes of the detections.
plot (bool, optional): Whether to plot PR curves or not. Defaults to False.
save_dir (Path, optional): Directory to save the PR curves. Defaults to an empty path.
names (tuple, optional): Tuple of class names to plot PR curves. Defaults to an empty tuple.
eps (float, optional): A small value to avoid division by zero. Defaults to 1e-16.
prefix (str, optional): A prefix string for saving the plot files. Defaults to an empty string.
Returns:
(tuple): A tuple of six arrays and one array of unique classes, where:
tp (np.ndarray): True positive counts for each class.
fp (np.ndarray): False positive counts for each class.
p (np.ndarray): Precision values at each confidence threshold.
r (np.ndarray): Recall values at each confidence threshold.
f1 (np.ndarray): F1-score values at each confidence threshold.
ap (np.ndarray): Average precision for each class at different IoU thresholds.
unique_classes (np.ndarray): An array of unique classes that have data.
"""
# Sort by objectness
@ -411,6 +426,32 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir=Path(), na
class Metric:
"""
Class for computing evaluation metrics for YOLOv8 model.
Attributes:
p (list): Precision for each class. Shape: (nc,).
r (list): Recall for each class. Shape: (nc,).
f1 (list): F1 score for each class. Shape: (nc,).
all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
ap_class_index (list): Index of class for each AP score. Shape: (nc,).
nc (int): Number of classes.
Methods:
ap50(): AP at IoU threshold of 0.5 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
ap(): AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: List of AP scores. Shape: (nc,) or [].
mp(): Mean precision of all classes. Returns: Float.
mr(): Mean recall of all classes. Returns: Float.
map50(): Mean AP at IoU threshold of 0.5 for all classes. Returns: Float.
map75(): Mean AP at IoU threshold of 0.75 for all classes. Returns: Float.
map(): Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. Returns: Float.
mean_results(): Mean of results, returns mp, mr, map50, map.
class_result(i): Class-aware result, returns p[i], r[i], ap50[i], ap[i].
maps(): mAP of each class. Returns: Array of mAP scores, shape: (nc,).
fitness(): Model fitness as a weighted combination of metrics. Returns: Float.
update(results): Update metric attributes with new evaluation results.
"""
def __init__(self) -> None:
self.p = [] # (nc, )
@ -420,10 +461,14 @@ class Metric:
self.ap_class_index = [] # (nc, )
self.nc = 0
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
@property
def ap50(self):
"""AP@0.5 of all classes.
Return:
Returns:
(nc, ) or [].
"""
return self.all_ap[:, 0] if len(self.all_ap) else []
@ -431,7 +476,7 @@ class Metric:
@property
def ap(self):
"""AP@0.5:0.95
Return:
Returns:
(nc, ) or [].
"""
return self.all_ap.mean(1) if len(self.all_ap) else []
@ -439,7 +484,7 @@ class Metric:
@property
def mp(self):
"""mean precision of all classes.
Return:
Returns:
float.
"""
return self.p.mean() if len(self.p) else 0.0
@ -447,7 +492,7 @@ class Metric:
@property
def mr(self):
"""mean recall of all classes.
Return:
Returns:
float.
"""
return self.r.mean() if len(self.r) else 0.0
@ -455,7 +500,7 @@ class Metric:
@property
def map50(self):
"""Mean AP@0.5 of all classes.
Return:
Returns:
float.
"""
return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
@ -463,7 +508,7 @@ class Metric:
@property
def map75(self):
"""Mean AP@0.75 of all classes.
Return:
Returns:
float.
"""
return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
@ -471,7 +516,7 @@ class Metric:
@property
def map(self):
"""Mean AP@0.5:0.95 of all classes.
Return:
Returns:
float.
"""
return self.all_ap.mean() if len(self.all_ap) else 0.0
@ -506,6 +551,32 @@ class Metric:
class DetMetrics:
"""
This class is a utility class for computing detection metrics such as precision, recall, and mean average precision
(mAP) of an object detection model.
Args:
save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
names (tuple of str): A tuple of strings that represents the names of the classes. Defaults to an empty tuple.
Attributes:
save_dir (Path): A path to the directory where the output plots will be saved.
plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
names (tuple of str): A tuple of strings that represents the names of the classes.
box (Metric): An instance of the Metric class for storing the results of the detection metrics.
speed (dict): A dictionary for storing the execution time of different parts of the detection process.
Methods:
process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
keys: Returns a list of keys for accessing the computed detection metrics.
mean_results: Returns a list of mean values for the computed detection metrics.
class_result(i): Returns a list of values for the computed detection metrics for a specific class.
maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.
fitness: Computes the fitness score based on the computed detection metrics.
ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
results_dict: Returns a dictionary that maps detection metric keys to their computed values.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
self.save_dir = save_dir
@ -514,6 +585,10 @@ class DetMetrics:
self.box = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp, conf, pred_cls, target_cls):
results = ap_per_class(tp, conf, pred_cls, target_cls, plot=self.plot, save_dir=self.save_dir,
names=self.names)[2:]
@ -548,6 +623,31 @@ class DetMetrics:
class SegmentMetrics:
"""
Calculates and aggregates detection and segmentation metrics over a given set of classes.
Args:
save_dir (Path): Path to the directory where the output plots should be saved. Default is the current directory.
plot (bool): Whether to save the detection and segmentation plots. Default is False.
names (list): List of class names. Default is an empty list.
Attributes:
save_dir (Path): Path to the directory where the output plots should be saved.
plot (bool): Whether to save the detection and segmentation plots.
names (list): List of class names.
box (Metric): An instance of the Metric class to calculate box detection metrics.
seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
speed (dict): Dictionary to store the time taken in different phases of inference.
Methods:
process(tp_m, tp_b, conf, pred_cls, target_cls): Processes metrics over the given set of predictions.
mean_results(): Returns the mean of the detection and segmentation metrics over all the classes.
class_result(i): Returns the detection and segmentation metrics of class `i`.
maps: Returns the mean Average Precision (mAP) scores for IoU thresholds ranging from 0.50 to 0.95.
fitness: Returns the fitness scores, which are a single weighted combination of metrics.
ap_class_index: Returns the list of indices of classes used to compute Average Precision (AP).
results_dict: Returns the dictionary containing all the detection and segmentation metrics and fitness score.
"""
def __init__(self, save_dir=Path('.'), plot=False, names=()) -> None:
self.save_dir = save_dir
@ -557,7 +657,22 @@ class SegmentMetrics:
self.seg = Metric()
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, tp_m, tp_b, conf, pred_cls, target_cls):
"""
Processes the detection and segmentation metrics over the given set of predictions.
Args:
tp_m (list): List of True Positive masks.
tp_b (list): List of True Positive boxes.
conf (list): List of confidence scores.
pred_cls (list): List of predicted classes.
target_cls (list): List of target classes.
"""
results_mask = ap_per_class(tp_m,
conf,
pred_cls,
@ -610,12 +725,32 @@ class SegmentMetrics:
class ClassifyMetrics:
"""
Class for computing classification metrics including top-1 and top-5 accuracy.
Attributes:
top1 (float): The top-1 accuracy.
top5 (float): The top-5 accuracy.
speed (Dict[str, float]): A dictionary containing the time taken for each step in the pipeline.
Properties:
fitness (float): The fitness of the model, which is equal to top-5 accuracy.
results_dict (Dict[str, Union[float, str]]): A dictionary containing the classification metrics and fitness.
keys (List[str]): A list of keys for the results_dict.
Methods:
process(targets, pred): Processes the targets and predictions to compute classification metrics.
"""
def __init__(self) -> None:
self.top1 = 0
self.top5 = 0
self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
def __getattr__(self, attr):
name = self.__class__.__name__
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
def process(self, targets, pred):
# target classes and predicted classes
pred, targets = torch.cat(pred), torch.cat(targets)

@ -301,14 +301,14 @@ def plot_images(images,
# Plot masks
if len(masks):
if masks.max() > 1.0: # mean that masks are overlap
if idx.shape[0] == masks.shape[0]: # overlap_masks=False
image_masks = masks[idx]
else: # overlap_masks=True
image_masks = masks[[i]] # (1, 640, 640)
nl = idx.sum()
index = np.arange(nl).reshape(nl, 1, 1) + 1
image_masks = np.repeat(image_masks, nl, axis=0)
image_masks = np.where(image_masks == index, 1.0, 0.0)
else:
image_masks = masks[idx]
im = np.asarray(annotator.im).copy()
for j, box in enumerate(boxes.T.tolist()):

@ -52,7 +52,8 @@ class ClassificationPredictor(BasePredictor):
return log_string
prob = result.probs
# Print results
top5i = prob.argsort(0, descending=True)[:5].tolist() # top 5 indices
n5 = min(len(self.model.names), 5)
top5i = prob.argsort(0, descending=True)[:n5].tolist() # top 5 indices
log_string += f"{', '.join(f'{self.model.names[j]} {prob[j]:.2f}' for j in top5i)}, "
# write

@ -27,7 +27,8 @@ class ClassificationValidator(BaseValidator):
return batch
def update_metrics(self, preds, batch):
self.pred.append(preds.argsort(1, descending=True)[:, :5])
n5 = min(len(self.model.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5])
self.targets.append(batch['cls'])
def finalize_metrics(self, *args, **kwargs):

Loading…
Cancel
Save