`ultralytics 8.2.29` new fractional AutoBatch feature (#13446)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/13462/head v8.2.29
Glenn Jocher 6 months ago committed by GitHub
parent 2fe0946376
commit 6a234f3639
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 30
      CONTRIBUTING.md
  2. 10
      docs/en/modes/train.md
  3. 10
      docs/en/usage/cfg.md
  4. 6
      examples/README.md
  5. 2
      ultralytics/__init__.py
  6. 22
      ultralytics/cfg/__init__.py
  7. 24
      ultralytics/cfg/models/README.md
  8. 2
      ultralytics/engine/exporter.py
  9. 9
      ultralytics/engine/trainer.py
  10. 2
      ultralytics/models/yolo/model.py
  11. 8
      ultralytics/utils/autobatch.py
  12. 16
      ultralytics/utils/torch_utils.py

@ -1,12 +1,12 @@
---
comments: true
description: Learn how to contribute to Ultralytics YOLO projects – guidelines for pull requests, reporting bugs, code conduct and CLA signing.
keywords: Ultralytics, YOLO, open-source, contribute, pull request, bug report, coding guidelines, CLA, code of conduct, GitHub
description: Learn how to contribute to Ultralytics YOLO open-source repositories. Follow guidelines for pull requests, code of conduct, and bug reporting.
keywords: Ultralytics, YOLO, open-source, contribution, pull request, code of conduct, bug reporting, GitHub, CLA, Google-style docstrings
---
# Contributing to Ultralytics Open-Source YOLO Repositories
First of all, thank you for your interest in contributing to Ultralytics open-source YOLO repositories! Your contributions will help improve the project and benefit the community. This document provides guidelines and best practices to get you started.
Thank you for your interest in contributing to Ultralytics open-source YOLO repositories! Your contributions will enhance the project and benefit the entire community. This document provides guidelines and best practices to help you get started.
## Table of Contents
@ -21,27 +21,27 @@ First of all, thank you for your interest in contributing to Ultralytics open-so
## Code of Conduct
All contributors are expected to adhere to the [Code of Conduct](https://docs.ultralytics.com/help/code_of_conduct/) to ensure a welcoming and inclusive environment for everyone.
All contributors must adhere to the [Code of Conduct](code_of_conduct.md) to ensure a welcoming and inclusive environment for everyone.
## Contributing via Pull Requests
We welcome contributions in the form of pull requests. To make the review process smoother, please follow these guidelines:
We welcome contributions in the form of pull requests. To streamline the review process, please follow these guidelines:
1. **[Fork the repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo)**: Fork the Ultralytics YOLO repository to your own GitHub account.
1. **[Fork the repository](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/fork-a-repo)**: Fork the Ultralytics YOLO repository to your GitHub account.
2. **[Create a branch](https://docs.github.com/en/desktop/making-changes-in-a-branch/managing-branches-in-github-desktop)**: Create a new branch in your forked repository with a descriptive name for your changes.
3. **Make your changes**: Make the changes you want to contribute. Ensure that your changes follow the coding style of the project and do not introduce new errors or warnings.
3. **Make your changes**: Ensure that your changes follow the project's coding style and do not introduce new errors or warnings.
4. **[Test your changes](https://github.com/ultralytics/ultralytics/tree/main/tests)**: Test your changes locally to ensure that they work as expected and do not introduce new issues.
4. **[Test your changes](https://github.com/ultralytics/ultralytics/tree/main/tests)**: Test your changes locally to ensure they work as expected and do not introduce new issues.
5. **[Commit your changes](https://docs.github.com/en/desktop/making-changes-in-a-branch/committing-and-reviewing-changes-to-your-project-in-github-desktop)**: Commit your changes with a descriptive commit message. Make sure to include any relevant issue numbers in your commit message.
5. **[Commit your changes](https://docs.github.com/en/desktop/making-changes-in-a-branch/committing-and-reviewing-changes-to-your-project-in-github-desktop)**: Commit your changes with a descriptive commit message. Include any relevant issue numbers in your commit message.
6. **[Create a pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request)**: Create a pull request from your forked repository to the main Ultralytics YOLO repository. In the pull request description, provide a clear explanation of your changes and how they improve the project.
6. **[Create a pull request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request)**: Create a pull request from your forked repository to the main Ultralytics YOLO repository. Provide a clear explanation of your changes and how they improve the project.
### CLA Signing
Before we can accept your pull request, you need to sign a [Contributor License Agreement (CLA)](https://docs.ultralytics.com/help/CLA/). This is a legal document stating that you agree to the terms of contributing to the Ultralytics YOLO repositories. The CLA ensures that your contributions are properly licensed and that the project can continue to be distributed under the AGPL-3.0 license.
Before we can accept your pull request, you must sign a [Contributor License Agreement (CLA)](CLA.md). This legal document ensures that your contributions are properly licensed and that the project can continue to be distributed under the AGPL-3.0 license.
To sign the CLA, follow the instructions provided by the CLA bot after you submit your PR and add a comment in your PR saying:
@ -51,7 +51,7 @@ I have read the CLA Document and I sign the CLA
### Google-Style Docstrings
When adding new functions or classes, please include a [Google-style docstring](https://google.github.io/styleguide/pyguide.html) to provide clear and concise documentation for other developers. This will help ensure that your contributions are easy to understand and maintain.
When adding new functions or classes, include a [Google-style docstring](https://google.github.io/styleguide/pyguide.html) to provide clear and concise documentation for other developers. This helps ensure your contributions are easy to understand and maintain.
#### Google-style
@ -60,7 +60,7 @@ This example shows a Google-style docstring. Note that both input and output `ty
```python
def example_function(arg1, arg2=4):
"""
Example function that demonstrates Google-style docstrings.
This example shows a Google-style docstring. Note that both input and output `types` must always be enclosed by parentheses, i.e., `(bool)`.
Args:
arg1 (int): The first argument.
@ -84,7 +84,7 @@ This example shows both a Google-style docstring and argument and return type hi
```python
def example_function(arg1: int, arg2: int = 4) -> bool:
"""
Example function that demonstrates Google-style docstrings.
This example shows both a Google-style docstring and argument and return type hints, though both are not required; one can be used without the other.
Args:
arg1: The first argument.
@ -129,4 +129,4 @@ Users and developers are encouraged to familiarize themselves with the terms of
Thank you for your interest in contributing to [Ultralytics open-source](https://github.com/ultralytics) YOLO projects. Your participation is crucial in shaping the future of our software and fostering a community of innovation and collaboration. Whether you're improving code, reporting bugs, or suggesting features, your contributions make a significant impact.
We're eager to see your ideas in action and appreciate your commitment to advancing object detection technology. Let's continue to grow and innovate together in this exciting open-source journey. Happy coding! 🚀🌟
We look forward to seeing your ideas in action and appreciate your commitment to advancing object detection technology. Let's continue to grow and innovate together in this exciting open-source journey. Happy coding! 🚀🌟

@ -182,7 +182,7 @@ The training settings for YOLO models encompass various hyperparameters and conf
| `epochs` | `100` | Total number of training epochs. Each epoch represents a full pass over the entire dataset. Adjusting this value can affect training duration and model performance. |
| `time` | `None` | Maximum training time in hours. If set, this overrides the `epochs` argument, allowing training to automatically stop after the specified duration. Useful for time-constrained training scenarios. |
| `patience` | `100` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. |
| `batch` | `16` | Batch size for training, indicating how many images are processed before the model's internal parameters are updated. AutoBatch (`batch=-1`) dynamically adjusts the batch size based on GPU memory availability. |
| `batch` | `16` | Batch size, with three modes: set as an integer (e.g., `batch=16`), auto mode for 60% GPU memory utilization (`batch=-1`), or auto mode with specified utilization fraction (`batch=0.70`). |
| `imgsz` | `640` | Target image size for training. All images are resized to this dimension before being fed into the model. Affects model accuracy and computational complexity. |
| `save` | `True` | Enables saving of training checkpoints and final model weights. Useful for resuming training or model deployment. |
| `save_period` | `-1` | Frequency of saving model checkpoints, specified in epochs. A value of -1 disables this feature. Useful for saving interim models during long training sessions. |
@ -226,6 +226,14 @@ The training settings for YOLO models encompass various hyperparameters and conf
| `val` | `True` | Enables validation during training, allowing for periodic evaluation of model performance on a separate dataset. |
| `plots` | `False` | Generates and saves plots of training and validation metrics, as well as prediction examples, providing visual insights into model performance and learning progression. |
!!! info "Note on Batch-size Settings"
The `batch` argument can be configured in three ways:
- **Fixed Batch Size**: Set an integer value (e.g., `batch=16`), specifying the number of images per batch directly.
- **Auto Mode (60% GPU Memory)**: Use `batch=-1` to automatically adjust batch size for approximately 60% CUDA memory utilization.
- **Auto Mode with Utilization Fraction**: Set a fraction value (e.g., `batch=0.70`) to adjust batch size based on the specified fraction of GPU memory usage.
## Augmentation Settings and Hyperparameters
Augmentation techniques are essential for improving the robustness and performance of YOLO models by introducing variability into the training data, helping the model generalize better to unseen data. The following table outlines the purpose and effect of each augmentation argument:

@ -91,7 +91,7 @@ The training settings for YOLO models encompass various hyperparameters and conf
| `epochs` | `100` | Total number of training epochs. Each epoch represents a full pass over the entire dataset. Adjusting this value can affect training duration and model performance. |
| `time` | `None` | Maximum training time in hours. If set, this overrides the `epochs` argument, allowing training to automatically stop after the specified duration. Useful for time-constrained training scenarios. |
| `patience` | `100` | Number of epochs to wait without improvement in validation metrics before early stopping the training. Helps prevent overfitting by stopping training when performance plateaus. |
| `batch` | `16` | Batch size for training, indicating how many images are processed before the model's internal parameters are updated. AutoBatch (`batch=-1`) dynamically adjusts the batch size based on GPU memory availability. |
| `batch` | `16` | Batch size, with three modes: set as an integer (e.g., `batch=16`), auto mode for 60% GPU memory utilization (`batch=-1`), or auto mode with specified utilization fraction (`batch=0.70`). |
| `imgsz` | `640` | Target image size for training. All images are resized to this dimension before being fed into the model. Affects model accuracy and computational complexity. |
| `save` | `True` | Enables saving of training checkpoints and final model weights. Useful for resuming training or model deployment. |
| `save_period` | `-1` | Frequency of saving model checkpoints, specified in epochs. A value of -1 disables this feature. Useful for saving interim models during long training sessions. |
@ -135,6 +135,14 @@ The training settings for YOLO models encompass various hyperparameters and conf
| `val` | `True` | Enables validation during training, allowing for periodic evaluation of model performance on a separate dataset. |
| `plots` | `False` | Generates and saves plots of training and validation metrics, as well as prediction examples, providing visual insights into model performance and learning progression. |
!!! info "Note on Batch-size Settings"
The `batch` argument can be configured in three ways:
- **Fixed Batch Size**: Set an integer value (e.g., `batch=16`), specifying the number of images per batch directly.
- **Auto Mode (60% GPU Memory)**: Use `batch=-1` to automatically adjust batch size for approximately 60% CUDA memory utilization.
- **Auto Mode with Utilization Fraction**: Set a fraction value (e.g., `batch=0.70`) to adjust batch size based on the specified fraction of GPU memory usage.
[Train Guide](../modes/train.md){ .md-button }
## Predict Settings

@ -25,11 +25,11 @@ This repository features a collection of real-world applications and walkthrough
We greatly appreciate contributions from the community, including examples, applications, and guides. If you'd like to contribute, please follow these guidelines:
1. Create a pull request (PR) with the title prefix `[Example]`, adding your new example folder to the `examples/` directory within the repository.
2. Make sure your project adheres to the following standards:
1. **Create a pull request (PR)** with the title prefix `[Example]`, adding your new example folder to the `examples/` directory within the repository.
2. **Ensure your project adheres to the following standards:**
- Makes use of the `ultralytics` package.
- Includes a `README.md` with clear instructions for setting up and running the example.
- Refrains from adding large files or dependencies unless they are absolutely necessary for the example.
- Avoids adding large files or dependencies unless they are absolutely necessary for the example.
- Contributors should be willing to provide support for their examples and address related issues.
For more detailed information and guidance on contributing, please visit our [contribution documentation](https://docs.ultralytics.com/help/contributing).

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.28"
__version__ = "8.2.29"
import os

@ -95,10 +95,19 @@ CLI_HELP_MSG = f"""
"""
# Define keys for arg type checks
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time", "workspace"}
CFG_FRACTION_KEYS = {
CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0
"warmup_epochs",
"box",
"cls",
"dfl",
"degrees",
"shear",
"time",
"workspace",
"batch",
}
CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0
"dropout",
"iou",
"lr0",
"lrf",
"momentum",
@ -121,11 +130,10 @@ CFG_FRACTION_KEYS = {
"conf",
"iou",
"fraction",
} # fraction floats 0.0 - 1.0
CFG_INT_KEYS = {
}
CFG_INT_KEYS = { # integer-only arguments
"epochs",
"patience",
"batch",
"workers",
"seed",
"close_mosaic",
@ -136,7 +144,7 @@ CFG_INT_KEYS = {
"nbs",
"save_period",
}
CFG_BOOL_KEYS = {
CFG_BOOL_KEYS = { # boolean-only arguments
"save",
"exist_ok",
"verbose",

@ -1,6 +1,6 @@
## Models
Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks.
Welcome to the [Ultralytics](https://ultralytics.com) Models directory! Here you will find a wide variety of pre-configured model configuration files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks.
These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this directory provides a great starting point for your custom model development needs.
@ -8,26 +8,34 @@ To get started, simply browse through the models in this directory and find one
### Usage
Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
Model `*.yaml` files may be used directly in the [Command Line Interface (CLI)](https://docs.ultralytics.com/usage/cli) with a `yolo` command:
```bash
# Train a YOLOv8n model using the coco8 dataset for 100 epochs
yolo task=detect mode=train model=yolov8n.yaml data=coco8.yaml epochs=100
```
They may also be used directly in a Python environment, and accepts the same [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
They may also be used directly in a Python environment, and accept the same [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
```python
from ultralytics import YOLO
model = YOLO("model.yaml") # build a YOLOv8n model from scratch
# YOLO("model.pt") use pre-trained model if available
model.info() # display model information
model.train(data="coco8.yaml", epochs=100) # train the model
# Initialize a YOLOv8n model from a YAML configuration file
model = YOLO("model.yaml")
# If a pre-trained model is available, use it instead
# model = YOLO("model.pt")
# Display model information
model.info()
# Train the model using the COCO8 dataset for 100 epochs
model.train(data="coco8.yaml", epochs=100)
```
## Pre-trained Model Architectures
Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
Ultralytics supports many model architectures. Visit [Ultralytics Models](https://docs.ultralytics.com/models) to view detailed information and usage. Any of these models can be used by loading their configurations or pretrained checkpoints if available.
## Contribute New Models

@ -126,7 +126,7 @@ def gd_outputs(gd):
def try_export(inner_func):
"""YOLOv8 export decorator, i..e @try_export."""
"""YOLOv8 export decorator, i.e. @try_export."""
inner_args = get_default_args(inner_func)
def outer_func(*args, **kwargs):

@ -269,8 +269,13 @@ class BaseTrainer:
self.stride = gs # for multiscale training
# Batch size
if self.batch_size == -1 and RANK == -1: # single-GPU only, estimate best batch size
self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
self.args.batch = self.batch_size = check_train_batch_size(
model=self.model,
imgsz=self.args.imgsz,
amp=self.amp,
batch=self.batch_size,
)
# Dataloaders
batch_size = self.batch_size // max(world_size, 1)

@ -92,7 +92,7 @@ class YOLOWorld(Model):
Set classes.
Args:
classes (List(str)): A list of categories i.e ["person"].
classes (List(str)): A list of categories i.e. ["person"].
"""
self.model.set_classes(classes)
# Remove background if it's given

@ -10,9 +10,9 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
from ultralytics.utils.torch_utils import profile
def check_train_batch_size(model, imgsz=640, amp=True):
def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
"""
Check YOLO training batch size using the autobatch() function.
Compute optimal YOLO training batch size using the autobatch() function.
Args:
model (torch.nn.Module): YOLO model to check batch size for.
@ -24,7 +24,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
"""
with torch.cuda.amp.autocast(amp):
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
@ -43,7 +43,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
# Check device
prefix = colorstr("AutoBatch: ")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
device = next(model.parameters()).device # get model device
if device.type == "cpu":
LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")

@ -146,11 +146,17 @@ def select_device(device="", batch=0, newline=False, verbose=True):
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
n = len(devices) # device count
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
)
if n > 1: # multi-GPU
if batch < 1:
raise ValueError(
"AutoBatch with batch<1 not supported for Multi-GPU training, "
"please specify a valid batch size, i.e. batch=16."
)
if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count
raise ValueError(
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
)
space = " " * (len(s) + 1)
for i, d in enumerate(devices):
p = torch.cuda.get_device_properties(i)

Loading…
Cancel
Save