`ultralytics 8.0.226` Validator Path and Tuner space (#6901)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Muhammad Rizwan Munawar <chr043416@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: DennisJ <106725464+DennisJcy@users.noreply.github.com>
Co-authored-by: Kirill Ionkin <56236621+kirill-ionkin@users.noreply.github.com>
pull/6907/head v8.0.226
Glenn Jocher 1 year ago committed by GitHub
parent 6e660dfaaf
commit 412eb57fca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      docs/en/models/mobile-sam.md
  2. 4
      docs/en/models/sam.md
  3. 16
      docs/en/reference/solutions/heatmap.md
  4. 1
      docs/mkdocs.yml
  5. 2
      ultralytics/__init__.py
  6. 4
      ultralytics/engine/exporter.py
  7. 14
      ultralytics/engine/tuner.py
  8. 2
      ultralytics/engine/validator.py
  9. 17
      ultralytics/nn/autobackend.py
  10. 6
      ultralytics/solutions/ai_gym.py
  11. 7
      ultralytics/solutions/heatmap.py
  12. 7
      ultralytics/solutions/object_counter.py

@ -22,7 +22,7 @@ This table presents the available models with their specific pre-trained weights
| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export |
|------------|---------------------|----------------------------------------------|-----------|------------|----------|--------|
| MobileSAM | `mobile_sam.pt` | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | |
| MobileSAM | `mobile_sam.pt` | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | |
## Adapting from SAM to MobileSAM

@ -32,8 +32,8 @@ This table presents the available models with their specific pre-trained weights
| Model Type | Pre-trained Weights | Tasks Supported | Inference | Validation | Training | Export |
|------------|---------------------|----------------------------------------------|-----------|------------|----------|--------|
| SAM base | `sam_b.pt` | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | |
| SAM large | `sam_l.pt` | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | |
| SAM base | `sam_b.pt` | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | |
| SAM large | `sam_l.pt` | [Instance Segmentation](../tasks/segment.md) | ✅ | ❌ | ❌ | |
## How to Use SAM: Versatility and Power in Image Segmentation

@ -0,0 +1,16 @@
---
description: Explore Ultralytics YOLO's advanced Heatmaps feature designed to highlight areas of interest, providing an immediate, impactful way to interpret spatial information.
keywords: Ultralytics, YOLO, heatmaps, object tracking, data visualization, real-time tracking, machine learning, object counting, computer vision, retail analytics, YOLOv8, artificial intelligence
---
# Reference for `ultralytics/solutions/heatmap.py`
!!! Note
This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/heatmap.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/heatmap.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/heatmap.py) 🛠. Thank you 🙏!
<br><br>
## ::: ultralytics.solutions.heatmap.Heatmap
<br><br>

@ -408,6 +408,7 @@ nav:
- solutions:
- ai_gym: reference/solutions/ai_gym.md
- object_counter: reference/solutions/object_counter.md
- heatmap: reference/solutions/heatmap.md
- trackers:
- basetrack: reference/trackers/basetrack.md
- bot_sort: reference/trackers/bot_sort.md

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.225'
__version__ = '8.0.226'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM

@ -64,7 +64,7 @@ import torch
from ultralytics.cfg import get_cfg
from ultralytics.data.dataset import YOLODataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.nn.autobackend import check_class_names
from ultralytics.nn.autobackend import check_class_names, default_class_names
from ultralytics.nn.modules import C2f, Detect, RTDETRDecoder
from ultralytics.nn.tasks import DetectionModel, SegmentationModel
from ultralytics.utils import (ARM64, DEFAULT_CFG, LINUX, LOGGER, MACOS, ROOT, WINDOWS, __version__, callbacks,
@ -172,6 +172,8 @@ class Exporter:
self.device = select_device('cpu' if self.args.device is None else self.args.device)
# Checks
if not hasattr(model, 'names'):
model.names = default_class_names()
model.names = check_class_names(model.names)
if self.args.half and onnx and self.device.type == 'cpu':
LOGGER.warning('WARNING ⚠ half=True only compatible with GPU export, i.e. use device=0')

@ -56,6 +56,14 @@ class Tuner:
model = YOLO('yolov8n.pt')
model.tune(data='coco8.yaml', epochs=10, iterations=300, optimizer='AdamW', plots=False, save=False, val=False)
```
Tune with custom search space.
```python
from ultralytics import YOLO
model = YOLO('yolov8n.pt')
model.tune(space={key1: val1, key2: val2}) # custom search space dictionary
```
"""
def __init__(self, args=DEFAULT_CFG, _callbacks=None):
@ -65,10 +73,9 @@ class Tuner:
Args:
args (dict, optional): Configuration for hyperparameter evolution.
"""
self.args = get_cfg(overrides=args)
self.space = { # key: (min, max, gain(optional))
self.space = args.pop('space', None) or { # key: (min, max, gain(optional))
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
'lr0': (1e-5, 1e-1),
'lr0': (1e-5, 1e-1), # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
'lrf': (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf)
'momentum': (0.7, 0.98, 0.3), # SGD momentum/Adam beta1
'weight_decay': (0.0, 0.001), # optimizer weight decay 5e-4
@ -90,6 +97,7 @@ class Tuner:
'mosaic': (0.0, 1.0), # image mixup (probability)
'mixup': (0.0, 1.0), # image mixup (probability)
'copy_paste': (0.0, 1.0)} # segment copy-paste (probability)
self.args = get_cfg(overrides=args)
self.tune_dir = get_save_dir(self.args, name='tune')
self.tune_csv = self.tune_dir / 'tune_results.csv'
self.callbacks = _callbacks or callbacks.get_default_callbacks()

@ -135,7 +135,7 @@ class BaseValidator:
self.args.batch = 1 # export.py models default to batch-size 1
LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.split('.')[-1] in ('yaml', 'yml'):
if str(self.args.data).split('.')[-1] in ('yaml', 'yml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data, split=self.args.split)

@ -40,6 +40,14 @@ def check_class_names(names):
return names
def default_class_names(data=None):
"""Applies default class names to an input YAML file or returns numerical class names."""
if data:
with contextlib.suppress(Exception):
return yaml_load(check_yaml(data))['names']
return {i: f'class{i}' for i in range(999)} # return default if above errors
class AutoBackend(nn.Module):
"""
Handles dynamic backend selection for running inference using Ultralytics YOLO models.
@ -315,7 +323,7 @@ class AutoBackend(nn.Module):
# Check names
if 'names' not in locals(): # names missing
names = self._apply_default_class_names(data)
names = default_class_names(data)
names = check_class_names(names)
# Disable gradients
@ -479,13 +487,6 @@ class AutoBackend(nn.Module):
for _ in range(2 if self.jit else 1):
self.forward(im) # warmup
@staticmethod
def _apply_default_class_names(data):
"""Applies default class names to an input YAML file or returns numerical class names."""
with contextlib.suppress(Exception):
return yaml_load(check_yaml(data))['names']
return {i: f'class{i}' for i in range(999)} # return default if above errors
@staticmethod
def _model_type(p='path/to/model.pt'):
"""

@ -2,6 +2,7 @@
import cv2
from ultralytics.utils.checks import check_imshow
from ultralytics.utils.plotting import Annotator
@ -32,6 +33,9 @@ class AIGym:
self.view_img = False
self.annotator = None
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
def set_args(self,
kpts_to_check,
line_thickness=2,
@ -120,7 +124,7 @@ class AIGym:
self.annotator.kpts(k, shape=(640, 640), radius=1, kpt_line=True)
if self.view_img:
if self.env_check and self.view_img:
cv2.imshow('Ultralytics YOLOv8 AI GYM', self.im0)
if cv2.waitKey(1) & 0xFF == ord('q'):
return

@ -5,7 +5,7 @@ from collections import defaultdict
import cv2
import numpy as np
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.utils.plotting import Annotator
check_requirements('shapely>=2.0.0')
@ -50,6 +50,9 @@ class Heatmap:
self.count_reg_color = (0, 255, 0)
self.region_thickness = 5
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
def set_args(self,
imw,
imh,
@ -155,7 +158,7 @@ class Heatmap:
im0_with_heatmap = cv2.addWeighted(self.im0, 1 - self.heatmap_alpha, heatmap_colored, self.heatmap_alpha, 0)
if self.view_img:
if self.env_check and self.view_img:
self.display_frames(im0_with_heatmap)
return im0_with_heatmap

@ -4,7 +4,7 @@ from collections import defaultdict
import cv2
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.utils.plotting import Annotator, colors
check_requirements('shapely>=2.0.0')
@ -46,6 +46,9 @@ class ObjectCounter:
self.track_thickness = 2
self.draw_tracks = False
# Check if environment support imshow
self.env_check = check_imshow(warn=True)
def set_args(self,
classes_names,
reg_pts,
@ -136,7 +139,7 @@ class ObjectCounter:
else:
self.in_counts += 1
if self.view_img:
if self.env_check and self.view_img:
incount_label = 'InCount : ' + f'{self.in_counts}'
outcount_label = 'OutCount : ' + f'{self.out_counts}'
self.annotator.count_labels(in_count=incount_label, out_count=outcount_label)

Loading…
Cancel
Save