`ultralytics 8.0.171` new SAHI guide and callbacks fix (#4748)

Co-authored-by: chuzihang <49548797+Aria-Leo@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Sanghyun Choi <farewell518@gmail.com>
Co-authored-by: Awsome <1579093407@qq.com>
Co-authored-by: ConanQZ <49194386+ConanQZ@users.noreply.github.com>
pull/4756/head v8.0.171
Glenn Jocher 1 year ago committed by GitHub
parent aa9133bb88
commit 577d066fb2
15 changed files (lines changed in parentheses):

1. .pre-commit-config.yaml (2)
2. docs/guides/index.md (11)
3. docs/guides/sahi-tiled-inference.md (147)
4. docs/integrations/index.md (10)
5. mkdocs.yml (1)
6. ultralytics/__init__.py (2)
7. ultralytics/cfg/models/README.md (10)
8. ultralytics/engine/predictor.py (8)
9. ultralytics/models/yolo/classify/val.py (2)
10. ultralytics/models/yolo/detect/val.py (2)
11. ultralytics/trackers/README.md (10)
12. ultralytics/utils/callbacks/base.py (13)
13. ultralytics/utils/metrics.py (2)
14. ultralytics/utils/plotting.py (24)
15. ultralytics/utils/tal.py (4)

@@ -40,7 +40,7 @@ repos:
         name: YAPF formatting

   - repo: https://github.com/executablebooks/mdformat
-    rev: 0.7.16
+    rev: 0.7.17
     hooks:
       - id: mdformat
         name: MD formatting

@ -1,7 +1,7 @@
---
comments: true
description: In-depth exploration of Ultralytics' YOLO. Learn about the YOLO object detection model, how to train it on custom data, multi-GPU training, exporting, predicting, deploying, and more.
keywords: Ultralytics, YOLO, Deep Learning, Object detection, PyTorch, Tutorial, Multi-GPU training, Custom data training
keywords: Ultralytics, YOLO, Deep Learning, Object detection, PyTorch, Tutorial, Multi-GPU training, Custom data training, SAHI, Tiled Inference
---
# Comprehensive Tutorials to Ultralytics YOLO
@@ -16,5 +16,12 @@ Here's a compilation of in-depth guides to help you master different aspects of
 * [K-Fold Cross Validation](kfold-cross-validation.md) 🚀 NEW: Learn how to improve model generalization using K-Fold cross-validation technique.
 * [Hyperparameter Tuning](hyperparameter-tuning.md) 🚀 NEW: Discover how to optimize your YOLO models by fine-tuning hyperparameters using the Tuner class and genetic evolution algorithms.
+* [Using YOLOv8 with SAHI for Sliced Inference](sahi-tiled-inference.md) 🚀 NEW: Comprehensive guide on leveraging SAHI's sliced inference capabilities with YOLOv8 for object detection in high-resolution images.

 Note: More guides about training, exporting, predicting, and deploying with Ultralytics YOLO are coming soon. Stay tuned!

+## Contribute to Our Guides
+
+We welcome contributions from the community! If you've mastered a particular aspect of Ultralytics YOLO that's not yet covered in our guides, we encourage you to share your expertise. Writing a guide is a great way to give back to the community and help us make our documentation more comprehensive and user-friendly.
+
+To get started, please read our [Contributing Guide](https://docs.ultralytics.com/help/contributing) for guidelines on how to open up a Pull Request (PR) 🛠. We look forward to your contributions!
+
+Let's work together to make the Ultralytics YOLO ecosystem more robust and versatile 🙏!

@@ -0,0 +1,147 @@
---
comments: true
description: A comprehensive guide on how to use YOLOv8 with SAHI for standard and sliced inference in object detection tasks.
keywords: YOLOv8, SAHI, Sliced Inference, Object Detection, Ultralytics, Large Scale Image Analysis, High-Resolution Imagery
---
# Ultralytics Docs: Using YOLOv8 with SAHI for Sliced Inference
Welcome to the Ultralytics documentation on how to use YOLOv8 with SAHI (Slicing Aided Hyper Inference). In this comprehensive guide, we'll discuss what SAHI is, the benefits of sliced inference, and how to use SAHI with YOLOv8 for object detection tasks.
![SAHI Sliced Inference](https://raw.githubusercontent.com/obss/sahi/main/resources/sliced_inference.gif)
## Table of Contents
1. [Introduction to SAHI](#introduction-to-sahi)
2. [What is Sliced Inference?](#what-is-sliced-inference)
3. [Installation and Preparation](#installation-and-preparation)
4. [Standard Inference with YOLOv8](#standard-inference-with-yolov8)
5. [Sliced Inference with YOLOv8](#sliced-inference-with-yolov8)
6. [Handling Prediction Results](#handling-prediction-results)
7. [Batch Prediction](#batch-prediction)
## Introduction to SAHI
SAHI is a powerful library for performing efficient and accurate object detection over slices of an image, which is particularly useful for large-scale, high-resolution imagery. It integrates seamlessly with YOLO models and allows for more efficient use of computational resources.
## What is Sliced Inference?
Sliced Inference is a technique that divides a large image into smaller slices, performs object detection on each slice, and then aggregates the results back onto the original image. This method is especially beneficial when dealing with high-resolution images as it significantly reduces the computational load without sacrificing detection accuracy.
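To make the slicing step concrete, here is a minimal sketch of how overlapping slice windows can be computed. Note that `compute_slice_boxes` is a hypothetical helper written for illustration only; SAHI computes its slice regions (and merges the per-slice detections) internally:
```python
# Minimal sketch of overlap-aware slicing, for illustration only
def compute_slice_boxes(image_w, image_h, slice_w=256, slice_h=256, overlap=0.2):
    """Return (xmin, ymin, xmax, ymax) windows that tile an image with overlap."""
    step_x, step_y = int(slice_w * (1 - overlap)), int(slice_h * (1 - overlap))
    boxes = []
    for y in range(0, image_h, step_y):
        for x in range(0, image_w, step_x):
            # Clamp each window to the image border, shifting it back if needed
            xmax, ymax = min(x + slice_w, image_w), min(y + slice_h, image_h)
            boxes.append((max(0, xmax - slice_w), max(0, ymax - slice_h), xmax, ymax))
            if xmax >= image_w:
                break
        if ymax >= image_h:
            break
    return boxes

print(len(compute_slice_boxes(1920, 1080)))  # 60 overlapping 256x256 windows
```
Detections from each window are then mapped back to full-image coordinates and merged, which is the aggregation step SAHI performs for you.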
## Installation and Preparation
### Installation
To get started, install the latest versions of SAHI and Ultralytics:
```bash
pip install -U ultralytics sahi
```
### Import Modules and Download Resources
Here's how to import the necessary modules and download a YOLOv8 model and some test images:
```python
from sahi.utils.yolov8 import download_yolov8s_model
from sahi import AutoDetectionModel
from sahi.utils.cv import read_image
from sahi.utils.file import download_from_url
from sahi.predict import get_prediction, get_sliced_prediction, predict
from pathlib import Path
from IPython.display import Image
# Download YOLOv8 model
yolov8_model_path = "models/yolov8s.pt"
download_yolov8s_model(yolov8_model_path)
# Download test images
download_from_url('https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/small-vehicles1.jpeg', 'demo_data/small-vehicles1.jpeg')
download_from_url('https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/terrain2.png', 'demo_data/terrain2.png')
```
## Standard Inference with YOLOv8
### Instantiate the Model
You can instantiate a YOLOv8 model for object detection like this:
```python
detection_model = AutoDetectionModel.from_pretrained(
model_type='yolov8',
model_path=yolov8_model_path,
confidence_threshold=0.3,
device="cpu", # or 'cuda:0'
)
```
### Perform Standard Prediction
Perform standard inference using an image path or a numpy image.
```python
# With an image path
result = get_prediction("demo_data/small-vehicles1.jpeg", detection_model)
# With a numpy image
result = get_prediction(read_image("demo_data/small-vehicles1.jpeg"), detection_model)
```
### Visualize Results
Export and visualize the predicted bounding boxes and masks:
```python
result.export_visuals(export_dir="demo_data/")
Image("demo_data/prediction_visual.png")
```
## Sliced Inference with YOLOv8
Perform sliced inference by specifying the slice dimensions and overlap ratios:
```python
result = get_sliced_prediction(
"demo_data/small-vehicles1.jpeg",
detection_model,
slice_height=256,
slice_width=256,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2
)
```
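How the per-slice detections are merged can matter as much as the slice size. The call below is the same API with SAHI's postprocessing parameters made explicit; the parameter names (`postprocess_type`, `postprocess_match_threshold`) reflect recent SAHI releases, so verify them against your installed version:
```python
result = get_sliced_prediction(
    "demo_data/small-vehicles1.jpeg",
    detection_model,
    slice_height=256,
    slice_width=256,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="NMS",  # merge overlapping slice detections with NMS
    postprocess_match_threshold=0.5,  # match threshold used when merging
)
```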
## Handling Prediction Results
SAHI provides a `PredictionResult` object, which can be converted into various annotation formats:
```python
# Access the object prediction list
object_prediction_list = result.object_prediction_list
# Convert to COCO annotation, COCO prediction, imantics, and fiftyone formats
result.to_coco_annotations()[:3]
result.to_coco_predictions(image_id=1)[:3]
result.to_imantics_annotations()[:3]
result.to_fiftyone_detections()[:3]
```
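You can also iterate the predictions directly. The sketch below assumes the `bbox`, `score`, and `category` attributes exposed by recent SAHI releases; verify the attribute names against your installed version:
```python
# Access raw values of each ObjectPrediction (attribute names per recent SAHI releases)
for pred in object_prediction_list:
    x1, y1, x2, y2 = pred.bbox.minx, pred.bbox.miny, pred.bbox.maxx, pred.bbox.maxy
    print(f"{pred.category.name}: {pred.score.value:.2f} at ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})")
```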
## Batch Prediction
For batch prediction on a directory of images:
```python
predict(
model_type="yolov8",
model_path="path/to/yolov8n.pt",
model_device="cpu", # or 'cuda:0'
model_confidence_threshold=0.4,
source="path/to/dir",
slice_height=256,
slice_width=256,
overlap_height_ratio=0.2,
overlap_width_ratio=0.2,
)
```
That's it! Now you're equipped to use YOLOv8 with SAHI for both standard and sliced inference.

@@ -59,3 +59,13 @@ We also support a variety of model export formats for deployment in different environments.
 | [NCNN](https://github.com/Tencent/ncnn) | `ncnn` | `yolov8n_ncnn_model/` | ✅ | `imgsz`, `half` |

 Explore the links to learn more about each integration and how to get the most out of them with Ultralytics.
+
+## Contribute to Our Integrations
+
+We're always excited to see how the community integrates Ultralytics YOLO with other technologies, tools, and platforms! If you have successfully integrated YOLO with a new system or have valuable insights to share, consider contributing to our Integrations Docs.
+
+By writing a guide or tutorial, you can help expand our documentation and provide real-world examples that benefit the community. It's an excellent way to contribute to the growing ecosystem around Ultralytics YOLO.
+
+To contribute, please check out our [Contributing Guide](https://docs.ultralytics.com/help/contributing) for instructions on how to submit a Pull Request (PR) 🛠. We eagerly await your contributions!
+
+Let's collaborate to make the Ultralytics YOLO ecosystem more expansive and feature-rich 🙏!

@@ -216,6 +216,7 @@ nav:
       - guides/index.md
       - K-Fold Cross Validation: guides/kfold-cross-validation.md
       - Hyperparameter Tuning: guides/hyperparameter-tuning.md
+      - SAHI Tiled Inference: guides/sahi-tiled-inference.md
   - Integrations:
       - integrations/index.md
       - OpenVINO: integrations/openvino.md

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.170'
+__version__ = '8.0.171'

 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM

@@ -30,6 +30,12 @@ model.train(data="coco128.yaml", epochs=100)  # train the model
 Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.

-## Contributing New Models
+## Contribute New Models

-If you've developed a new model architecture or have improvements for existing models that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).
+Have you trained a new YOLO variant or achieved state-of-the-art performance with specific tuning? We'd love to showcase your work in our Models section! Contributions from the community in the form of new models, architectures, or optimizations are highly valued and can significantly enrich our repository.
+
+By contributing to this section, you're helping us offer a wider array of model choices and configurations to the community. It's a fantastic way to share your knowledge and expertise while making the Ultralytics YOLO ecosystem even more versatile.
+
+To get started, please consult our [Contributing Guide](https://docs.ultralytics.com/help/contributing) for step-by-step instructions on how to submit a Pull Request (PR) 🛠. Your contributions are eagerly awaited!
+
+Let's join hands to extend the range and capabilities of the Ultralytics YOLO models 🙏!

@@ -121,11 +121,11 @@ class BasePredictor:
             im = np.ascontiguousarray(im)  # contiguous
             im = torch.from_numpy(im)

-        img = im.to(self.device)
-        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
+        im = im.to(self.device)
+        im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
         if not_tensor:
-            img /= 255  # 0 - 255 to 0.0 - 1.0
-        return img
+            im /= 255  # 0 - 255 to 0.0 - 1.0
+        return im

     def inference(self, im, *args, **kwargs):
         visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,

@@ -42,7 +42,7 @@ class ClassificationValidator(BaseValidator):
         """Initialize confusion matrix, class names, and top-1 and top-5 accuracy."""
         self.names = model.names
         self.nc = len(model.names)
-        self.confusion_matrix = ConfusionMatrix(nc=self.nc, task='classify')
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf, task='classify')
         self.pred = []
         self.targets = []

@@ -68,7 +68,7 @@ class DetectionValidator(BaseValidator):
         self.nc = len(model.names)
         self.metrics.names = self.names
         self.metrics.plot = self.args.plots
-        self.confusion_matrix = ConfusionMatrix(nc=self.nc)
+        self.confusion_matrix = ConfusionMatrix(nc=self.nc, conf=self.args.conf)
         self.seen = 0
         self.jdict = []
         self.stats = []

@@ -83,6 +83,12 @@ yolo pose track source=... tracker=...
 By default, trackers will use the configuration in `ultralytics/cfg/trackers`. We also support using a modified tracker config file. Please refer to the tracker config files in `ultralytics/cfg/trackers`.

-## Contributing New Trackers
+## Contribute to Our Trackers Section

-If you've developed a new tracker architecture or have improvements for existing trackers that you'd like to contribute to the Ultralytics community, please submit your contribution in a new Pull Request. For more details, visit our [Contributing Guide](https://docs.ultralytics.com/help/contributing).
+Are you proficient in multi-object tracking and have successfully implemented or adapted a tracking algorithm with Ultralytics YOLO? We invite you to contribute to our Trackers section! Your real-world applications and solutions could be invaluable for users working on tracking tasks.
+
+By contributing to this section, you help expand the scope of tracking solutions available within the Ultralytics YOLO framework, adding another layer of functionality and utility for the community.
+
+To initiate your contribution, please refer to our [Contributing Guide](https://docs.ultralytics.com/help/contributing) for comprehensive instructions on submitting a Pull Request (PR) 🛠. We are excited to see what you bring to the table!
+
+Together, let's enhance the tracking capabilities of the Ultralytics YOLO ecosystem 🙏!

@@ -198,7 +198,8 @@ def add_integration_callbacks(instance):
     """
     # Load HUB callbacks
-    from .hub import callbacks
+    from .hub import callbacks as hub_cb
+    callbacks_list = [hub_cb]

     # Load training callbacks
     if 'Trainer' in instance.__class__.__name__:
@@ -210,13 +211,15 @@ def add_integration_callbacks(instance):
         from .raytune import callbacks as tune_cb
         from .tensorboard import callbacks as tb_cb
         from .wb import callbacks as wb_cb
-        callbacks.update({**clear_cb, **comet_cb, **dvc_cb, **mlflow_cb, **neptune_cb, **tune_cb, **tb_cb, **wb_cb})
+        callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])

     # Load export callbacks (patch to avoid CoreML protobuf error)
     if 'Exporter' in instance.__class__.__name__:
         from .tensorboard import callbacks as tb_cb
-        callbacks.update(tb_cb)
+        callbacks_list.append(tb_cb)

     # Add the callbacks to the callbacks dictionary
-    for k, v in callbacks.items():
-        if v not in instance.callbacks[k]:  # prevent duplicate callbacks addition
-            instance.callbacks[k].append(v)  # callback[name].append(func)
+    for callbacks in callbacks_list:
+        for k, v in callbacks.items():
+            if v not in instance.callbacks[k]:
+                instance.callbacks[k].append(v)
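Context for the `base.py` change above: the old code merged all integration callback dicts into one dict with `update()`, so integrations registering a hook under the same event name silently overwrote one another (and the mutation hit the imported hub `callbacks` dict in place). A standalone sketch with hypothetical callback dicts shows the difference:
```python
# Two hypothetical integrations hooking the same event name
comet_cb = {'on_train_start': lambda trainer: print('comet hook')}
wandb_cb = {'on_train_start': lambda trainer: print('wandb hook')}

merged = {**comet_cb, **wandb_cb}
print(len(merged))  # 1 -- the wandb hook silently replaced the comet hook

# The list-based fix iterates each dict and appends every hook instead
registry = {'on_train_start': []}
for cb in (comet_cb, wandb_cb):
    for event, func in cb.items():
        if func not in registry[event]:  # prevent duplicate callback addition
            registry[event].append(func)
print(len(registry['on_train_start']))  # 2 -- both hooks are registered
```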

@@ -189,7 +189,7 @@ class ConfusionMatrix:
         self.task = task
         self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == 'detect' else np.zeros((nc, nc))
         self.nc = nc  # number of classes
-        self.conf = conf
+        self.conf = 0.25 if conf is None else conf  # argument may be None from default cfg
         self.iou_thres = iou_thres

     def process_cls_preds(self, preds, targets):

@@ -12,11 +12,10 @@ import torch
 from PIL import Image, ImageDraw, ImageFont
 from PIL import __version__ as pil_version

-from ultralytics.utils import LOGGER, TryExcept, plt_settings, threaded
+from ultralytics.utils import LOGGER, TryExcept, ops, plt_settings, threaded

 from .checks import check_font, check_version, is_ascii
 from .files import increment_path
-from .ops import clip_boxes, scale_image, xywh2xyxy, xyxy2xywh


 class Colors:
@@ -163,7 +162,7 @@ class Annotator:
         im_gpu = im_gpu * inv_alph_masks[-1] + mcs
         im_mask = (im_gpu * 255)
         im_mask_np = im_mask.byte().cpu().numpy()
-        self.im[:] = im_mask_np if retina_masks else scale_image(im_mask_np, self.im.shape)
+        self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
         if self.pil:
             # Convert im back to PIL and update draw
             self.fromarray(self.im)
@@ -268,8 +267,9 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
     import pandas as pd
     import seaborn as sn

-    # Filter matplotlib>=3.7.2 warning
+    # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
     warnings.filterwarnings('ignore', category=UserWarning, message='The figure layout has changed to tight')
+    warnings.filterwarnings('ignore', category=FutureWarning)

     # Plot dataset labels
     LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
@@ -285,8 +285,8 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
     # Matplotlib labels
     ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
     y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
-    with contextlib.suppress(Exception):  # color histogram bars by class
-        [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)]  # known issue #3195
+    for i in range(nc):
+        y[2].patches[i].set_color([x / 255 for x in colors(i)])
     ax[0].set_ylabel('instances')
     if 0 < len(names) < 30:
         ax[0].set_xticks(range(len(names)))
@@ -298,7 +298,7 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(''), on_plot=None):
     # Rectangles
     boxes[:, 0:2] = 0.5  # center
-    boxes = xywh2xyxy(boxes) * 1000
+    boxes = ops.xywh2xyxy(boxes) * 1000
     img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
     for cls, box in zip(cls[:500], boxes[:500]):
         ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls))  # plot
@@ -348,12 +348,12 @@ def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False,
     if not isinstance(xyxy, torch.Tensor):  # may be list
         xyxy = torch.stack(xyxy)
-    b = xyxy2xywh(xyxy.view(-1, 4))  # boxes
+    b = ops.xyxy2xywh(xyxy.view(-1, 4))  # boxes
     if square:
         b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square
     b[:, 2:] = b[:, 2:] * gain + pad  # box wh * gain + pad
-    xyxy = xywh2xyxy(b).long()
-    clip_boxes(xyxy, im.shape)
+    xyxy = ops.xywh2xyxy(b).long()
+    ops.clip_boxes(xyxy, im.shape)
     crop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]
     if save:
         file.parent.mkdir(parents=True, exist_ok=True)  # make directory
@@ -425,7 +425,7 @@ def plot_images(images,
         classes = cls[idx].astype('int')

         if len(bboxes):
-            boxes = xywh2xyxy(bboxes[idx, :4]).T
+            boxes = ops.xywh2xyxy(bboxes[idx, :4]).T
             labels = bboxes.shape[1] == 4  # labels if no conf column
             conf = None if labels else bboxes[idx, 4]  # check for confidence presence (label vs pred)
@@ -554,7 +554,7 @@ def output_to_target(output, max_det=300):
     for i, o in enumerate(output):
         box, conf, cls = o[:max_det, :6].cpu().split((4, 1, 1), 1)
         j = torch.full((conf.shape[0], 1), i)
-        targets.append(torch.cat((j, cls, xyxy2xywh(box), conf), 1))
+        targets.append(torch.cat((j, cls, ops.xyxy2xywh(box), conf), 1))
     targets = torch.cat(targets, 0).numpy()
     return targets[:, 0], targets[:, 1], targets[:, 2:]

@@ -14,7 +14,7 @@ def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
     Select the positive anchor center in gt.

     Args:
-        xy_centers (Tensor): shape(h*w, 4)
+        xy_centers (Tensor): shape(h*w, 2)
         gt_bboxes (Tensor): shape(b, n_boxes, 4)

     Returns:
@@ -228,7 +228,7 @@ class TaskAlignedAssigner(nn.Module):
         target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
         target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

-        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
+        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
         target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]

         # Assigned target scores
