---
comments: true
description: Discover how to extend the utility of the Ultralytics package to support your development process.
keywords: Ultralytics, YOLO, custom, function, workflow, utility, support
---

# Simple Utilities

<p align="center">
  <img src="https://github.com/ultralytics/ultralytics/assets/62214284/516112de-4567-49f8-b93f-b55a10b79dd7" alt="code with perspective">
</p>

The `ultralytics` package comes with a myriad of utilities that can support, enhance, and speed up your workflows. There are many more available, but the ones below will be useful for most developers. They're also a great reference point when learning to program.

## Data

### YOLO Data Explorer

[YOLO Explorer](../datasets/explorer/index.md) was added in the `8.1.0` anniversary update and is a powerful tool you can use to better understand your dataset. One of the key features YOLO Explorer provides is the ability to use text queries to find object instances in your dataset.

### Auto Labeling / Annotations

Dataset annotation is a very resource-intensive and time-consuming process. If you have a YOLO object detection model trained on a reasonable amount of data, you can use it together with [SAM](../models/sam.md) to auto-annotate additional data in segmentation format.

```{ .py .annotate }
from ultralytics.data.annotator import auto_annotate

auto_annotate(  # (1)!
    data="path/to/new/data",
    det_model="yolov8n.pt",
    sam_model="mobile_sam.pt",
    device="cuda",
    output_dir="path/to/save_labels",
)
```

1. This function does not return anything.

- [See the reference section for `annotator.auto_annotate`](../reference/data/annotator.md#ultralytics.data.annotator.auto_annotate) for more insight on how the function operates.

- Use in combination with the [function `segments2boxes`](#convert-segments-to-bounding-boxes) to generate object detection bounding boxes as well.

### Convert COCO into YOLO Format

Use this function to convert COCO JSON annotations into the YOLO format. For object detection (bounding box) datasets, `use_segments` and `use_keypoints` should both be `False`.

```{ .py .annotate }
from ultralytics.data.converter import convert_coco

convert_coco(  # (1)!
    "../datasets/coco/annotations/",
    use_segments=False,
    use_keypoints=False,
    cls91to80=True,
)
```

1. This function does not return anything.

For additional information about the `convert_coco` function, [visit the reference page](../reference/data/converter.md#ultralytics.data.converter.convert_coco).

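For bounding boxes, the core of the conversion is a change of box convention: COCO stores `[x_min, y_min, width, height]` in pixels, while a YOLO label line stores the class index followed by a box center and size normalized to the image dimensions. A minimal sketch of that per-box arithmetic, using a hypothetical helper `coco_bbox_to_yolo` (not part of the Ultralytics API); `convert_coco` additionally handles class remapping (`cls91to80`), segments, keypoints, and writing label files, none of which is shown here.

```python
def coco_bbox_to_yolo(bbox, img_w, img_h):
    """Convert a COCO box [x_min, y_min, width, height] in pixels into a
    YOLO box [x_center, y_center, width, height] normalized to 0-1."""
    x_min, y_min, w, h = bbox
    x_c = (x_min + w / 2) / img_w  # box center, as a fraction of image width
    y_c = (y_min + h / 2) / img_h  # box center, as a fraction of image height
    return [x_c, y_c, w / img_w, h / img_h]


# A 100x50 px box whose top-left corner is at (200, 100) in a 640x480 image
print(coco_bbox_to_yolo([200, 100, 100, 50], img_w=640, img_h=480))
```
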
### Convert Bounding Boxes to Segments

With existing `x y w h` bounding box data, convert them to segments using the `yolo_bbox2segment` function. The files for images and annotations need to be organized like this:

```
data
├─ images
│   ├─ 001.jpg
│   ├─ 002.jpg
│   ├─ ..
│   └─ NNN.jpg
└─ labels
    ├─ 001.txt
    ├─ 002.txt
    ├─ ..
    └─ NNN.txt
```

```{ .py .annotate }
from ultralytics.data.converter import yolo_bbox2segment

yolo_bbox2segment(  # (1)!
    im_dir="path/to/images",
    save_dir=None,  # None saves to a "labels-segment" directory next to the images directory
    sam_model="sam_b.pt",
)
```

1. This function does not return anything.

[Visit the `yolo_bbox2segment` reference page](../reference/data/converter.md#ultralytics.data.converter.yolo_bbox2segment) for more information regarding the function.

### Convert Segments to Bounding Boxes

If you have a dataset that uses the [segmentation dataset format](../datasets/segment/index.md), you can easily convert it into upright (horizontal) bounding boxes (`x y w h` format) with this function.

```python
import numpy as np

from ultralytics.utils.ops import segments2boxes

segments = np.array(
    [
        [805, 392, 797, 400, ..., 808, 714, 808, 392],
        [115, 398, 113, 400, ..., 150, 400, 149, 298],
        [267, 412, 265, 413, ..., 300, 413, 299, 412],
    ]
)

segments2boxes([s.reshape(-1, 2) for s in segments])
>>> array([[ 741.66, 631.12, 133.31, 479.25],
           [ 146.81, 649.69, 185.62, 502.88],
           [ 281.81, 636.19, 118.12, 448.88]],
          dtype=float32)  # xywh bounding boxes
```

To understand how this function works, visit the [reference page](../reference/utils/ops.md#ultralytics.utils.ops.segments2boxes).

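Conceptually, the upright box around a polygon comes from the minimum and maximum of its x and y coordinates. A minimal NumPy sketch of that idea for a single flat polygon, using a hypothetical helper `segment_to_xywh`; it illustrates the geometry and is not the library's implementation.

```python
import numpy as np


def segment_to_xywh(segment):
    """Return the upright box around one polygon as [x_center, y_center, w, h]."""
    pts = np.asarray(segment, dtype=np.float32).reshape(-1, 2)  # (N, 2) x, y points
    x_min, y_min = pts.min(axis=0)
    x_max, y_max = pts.max(axis=0)
    return np.array(
        [(x_min + x_max) / 2, (y_min + y_max) / 2, x_max - x_min, y_max - y_min],
        dtype=np.float32,
    )


# Triangle with corners (10, 20), (50, 20) and (30, 80)
print(segment_to_xywh([10, 20, 50, 20, 30, 80]))  # [30. 50. 40. 60.]
```
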
## Utilities

### Image Compression

Compresses a single image file to a reduced size while preserving its aspect ratio and quality. If the input image is smaller than the maximum dimension, it will not be resized.

```{ .py .annotate }
from pathlib import Path

from ultralytics.data.utils import compress_one_image

for f in Path("path/to/dataset").rglob("*.jpg"):
    compress_one_image(f)  # (1)!
```

1. This function does not return anything.

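The resize rule above comes down to a single ratio between the maximum dimension and the image's longest side. A sketch of the dimension arithmetic only (no image I/O), using a hypothetical helper `compressed_size`; the `max_dim=1920` default here is an assumption, so check the `compress_one_image` signature for the value your version uses.

```python
def compressed_size(width, height, max_dim=1920):
    """Return the (width, height) an image would be resized to so that its
    longest side is at most max_dim; smaller images keep their size."""
    r = max_dim / max(width, height)  # shrink factor for the longest side
    if r >= 1.0:  # already within max_dim: no resize
        return width, height
    return int(width * r), int(height * r)  # same ratio on both sides keeps the aspect ratio


print(compressed_size(3840, 2160))  # (1920, 1080)
print(compressed_size(810, 1080))  # (810, 1080)
```
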
### Auto-split Dataset

Automatically split a dataset into `train`/`val`/`test` splits and save the resulting splits into `autosplit_*.txt` files. This function uses random sampling, which is not applied when using the [`fraction` argument for training](../modes/train.md#arguments).

```{ .py .annotate }
from ultralytics.data.utils import autosplit

autosplit(  # (1)!
    path="path/to/images",
    weights=(0.9, 0.1, 0.0),  # (train, validation, test) fractional splits
    annotated_only=False,  # split only images with annotation file when True
)
```

1. This function does not return anything.

See the [Reference page](../reference/data/utils.md#ultralytics.data.utils.autosplit) for additional details on this function.

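Weighted random assignment of this kind can be sketched with Python's `random.choices`. The helper `split_files` below is hypothetical and only returns the lists in memory; it does not write `autosplit_*.txt` files or check for label files the way `autosplit` does.

```python
import random


def split_files(files, weights=(0.9, 0.1, 0.0), seed=0):
    """Randomly assign each file to train/val/test using the given weights."""
    rng = random.Random(seed)  # seeded so the split is reproducible
    splits = {"train": [], "val": [], "test": []}
    names = list(splits)
    # One draw per file: 0 -> train, 1 -> val, 2 -> test
    for f, idx in zip(files, rng.choices(range(3), weights=weights, k=len(files))):
        splits[names[idx]].append(f)
    return splits


files = [f"{i:03d}.jpg" for i in range(1, 101)]
splits = split_files(files)
print({k: len(v) for k, v in splits.items()})
```

Because the assignment is random, the split sizes only approximate the weights; a test weight of `0.0` means no file lands in `test`.
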
### Segment-polygon to Binary Mask

Convert a single polygon (as a list) to a binary mask of the specified image size. The polygon has the form `[N, 2]`, with `N` as the number of `(x, y)` points defining the polygon contour.

!!! warning

    `N` <b><u>must always</u></b> be even.

```python
import numpy as np

from ultralytics.data.utils import polygon2mask

imgsz = (1080, 810)
polygon = np.array(
    [805, 392, 797, 400, ..., 808, 714, 808, 392],  # (238, 2)
)

mask = polygon2mask(
    imgsz,  # tuple
    [polygon],  # input as list
    color=255,  # 8-bit binary
    downsample_ratio=1,
)
```

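To see what such a conversion does, here is a pure-NumPy sketch that fills a polygon using the even-odd (ray casting) point-in-polygon rule, tested at each pixel center. This is a swapped-in technique for illustration only: `polygon2mask` itself relies on OpenCV polygon filling, so pixels along the boundary may differ. `polygon_to_mask` is a hypothetical helper. The `reshape(-1, 2)` step also shows why a flattened coordinate list needs an even number of values.

```python
import numpy as np


def polygon_to_mask(imgsz, polygon, color=255):
    """Rasterize one polygon into an (h, w) uint8 mask using the even-odd rule.

    imgsz is (height, width); polygon is a flat [x1, y1, x2, y2, ...] sequence.
    """
    h, w = imgsz
    pts = np.asarray(polygon, dtype=np.float64).reshape(-1, 2)  # raises ValueError on an odd value count
    ys, xs = np.mgrid[0:h, 0:w]
    px, py = xs + 0.5, ys + 0.5  # test each pixel at its center
    inside = np.zeros((h, w), dtype=bool)
    x1, y1 = pts[:, 0], pts[:, 1]
    x2, y2 = np.roll(x1, -1), np.roll(y1, -1)  # pair each vertex with the next one (closing the polygon)
    for ax, ay, bx, by in zip(x1, y1, x2, y2):
        crosses = (ay > py) != (by > py)  # edge spans this pixel's row
        with np.errstate(divide="ignore", invalid="ignore"):  # horizontal edges never "cross"
            x_hit = ax + (py - ay) * (bx - ax) / (by - ay)  # x where the edge meets the row
        inside ^= crosses & (px < x_hit)  # toggle once per edge to the right of the pixel
    return np.where(inside, color, 0).astype(np.uint8)


mask = polygon_to_mask((6, 6), [1, 1, 4, 1, 4, 4, 1, 4])  # 3x3 px square
print(int((mask == 255).sum()))  # 9
```
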
## Bounding Boxes

### Bounding Box (horizontal) Instances

To manage bounding box data, the `Bboxes` class will help you convert between box coordinate formats, scale box dimensions, calculate areas, include offsets, and more!

```python
import numpy as np

from ultralytics.utils.instance import Bboxes

boxes = Bboxes(
    bboxes=np.array(
        [
            [22.878, 231.27, 804.98, 756.83],
            [48.552, 398.56, 245.35, 902.71],
            [669.47, 392.19, 809.72, 877.04],
            [221.52, 405.8, 344.98, 857.54],
            [0, 550.53, 63.01, 873.44],
            [0.0584, 254.46, 32.561, 324.87],
        ]
    ),
    format="xyxy",
)

boxes.areas()
>>> array([ 4.1104e+05, 99216, 68000, 55772, 20347, 2288.5])

boxes.convert("xywh")
boxes.bboxes
>>> array(
    [[ 413.93,  494.05,   782.1,  525.56],
     [ 146.95,  650.63,   196.8,  504.15],
     [  739.6,  634.62,  140.25,  484.85],
     [ 283.25,  631.67,  123.46,  451.74],
     [ 31.505,  711.99,   63.01,  322.91],
     [  16.31,  289.67,  32.503,   70.41]]
)
```

See the [`Bboxes` reference section](../reference/utils/instance.md#ultralytics.utils.instance.Bboxes) for more attributes and methods available.

!!! tip

    Many of the following functions (and more) can be accessed using the [`Bboxes` class](#bounding-box-horizontal-instances), but if you prefer to work with the functions directly, see the next subsections on how to import them independently.

### Scaling Boxes

When scaling an image up or down, the corresponding bounding box coordinates can be scaled to match using `ultralytics.utils.ops.scale_boxes`.

```{ .py .annotate }
import cv2 as cv
import numpy as np

from ultralytics.utils.ops import scale_boxes

image = cv.imread("ultralytics/assets/bus.jpg")
h, w = image.shape[:2]
resized = cv.resize(image, None, fx=1.2, fy=1.2)  # dsize=None: size comes from fx/fy
new_h, new_w = resized.shape[:2]

xyxy_boxes = np.array(
    [
        [22.878, 231.27, 804.98, 756.83],
        [48.552, 398.56, 245.35, 902.71],
        [669.47, 392.19, 809.72, 877.04],
        [221.52, 405.8, 344.98, 857.54],
        [0, 550.53, 63.01, 873.44],
        [0.0584, 254.46, 32.561, 324.87],
    ]
)

new_boxes = scale_boxes(
    img1_shape=(h, w),  # original image dimensions
    boxes=xyxy_boxes,  # boxes from original image
    img0_shape=(new_h, new_w),  # resized image dimensions (scale to)
    ratio_pad=None,
    padding=False,
    xywh=False,
)

new_boxes  # (1)!
>>> array(
    [[ 27.454,  277.52,  965.98,   908.2],
     [ 58.262,  478.27,  294.42,  1083.3],
     [ 803.36,  470.63,  971.66,  1052.4],
     [ 265.82,  486.96,  413.98,    1029],
     [      0,  660.64,  75.612,  1048.1],
     [ 0.0701,  305.35,  39.073,  389.84]]
)
```

1. Bounding boxes scaled for the new image size.

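Without letterbox padding, rescaling boils down to multiplying x coordinates by the width ratio and y coordinates by the height ratio; for a uniform resize like the 1.2× example above, both ratios are equal. A sketch of that arithmetic with a hypothetical helper `rescale_xyxy`; unlike `scale_boxes`, it does not handle padding or clip boxes to the image bounds.

```python
import numpy as np


def rescale_xyxy(boxes, old_hw, new_hw):
    """Scale xyxy pixel boxes from an image of size old_hw=(h, w) to new_hw=(h, w)."""
    gy = new_hw[0] / old_hw[0]  # vertical scale factor
    gx = new_hw[1] / old_hw[1]  # horizontal scale factor
    return np.asarray(boxes, dtype=np.float64) * np.array([gx, gy, gx, gy])


boxes = np.array([[10.0, 20.0, 110.0, 220.0]])
# 1080x810 image resized by 1.5 on both axes: every coordinate is multiplied by 1.5
print(rescale_xyxy(boxes, old_hw=(1080, 810), new_hw=(1620, 1215)))
```
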
### Bounding Box Format Conversions

#### XYXY → XYWH

Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format, where (x1, y1) is the top-left corner, (x2, y2) is the bottom-right corner, and (x, y) is the center of the box.

```python
import numpy as np

from ultralytics.utils.ops import xyxy2xywh

xyxy_boxes = np.array(
    [
        [22.878, 231.27, 804.98, 756.83],
        [48.552, 398.56, 245.35, 902.71],
        [669.47, 392.19, 809.72, 877.04],
        [221.52, 405.8, 344.98, 857.54],
        [0, 550.53, 63.01, 873.44],
        [0.0584, 254.46, 32.561, 324.87],
    ]
)
xywh = xyxy2xywh(xyxy_boxes)

xywh
>>> array(
    [[ 413.93,  494.05,   782.1,  525.56],
     [ 146.95,  650.63,   196.8,  504.15],
     [  739.6,  634.62,  140.25,  484.85],
     [ 283.25,  631.67,  123.46,  451.74],
     [ 31.505,  711.99,   63.01,  322.91],
     [  16.31,  289.67,  32.503,   70.41]]
)
```

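The arithmetic behind this conversion is small enough to write out. A NumPy sketch with hypothetical helpers `to_xywh` and its inverse `to_xyxy` (similar in spirit to `xyxy2xywh` and `xywh2xyxy`, but not the library code); both operate on the last axis, so single boxes and `(N, 4)` arrays work alike.

```python
import numpy as np


def to_xywh(xyxy):
    """(x1, y1, x2, y2) corners -> (x_center, y_center, width, height)."""
    xyxy = np.asarray(xyxy, dtype=np.float64)
    xy = (xyxy[..., :2] + xyxy[..., 2:]) / 2  # box center
    wh = xyxy[..., 2:] - xyxy[..., :2]  # box size
    return np.concatenate([xy, wh], axis=-1)


def to_xyxy(xywh):
    """(x_center, y_center, width, height) -> (x1, y1, x2, y2) corners."""
    xywh = np.asarray(xywh, dtype=np.float64)
    half = xywh[..., 2:] / 2  # half width, half height
    return np.concatenate([xywh[..., :2] - half, xywh[..., :2] + half], axis=-1)


box = np.array([[10.0, 20.0, 50.0, 80.0]])
print(to_xywh(box))  # [[30. 50. 40. 60.]]
```
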
### All Bounding Box Conversions

```python
from ultralytics.utils.ops import xywh2xyxy
from ultralytics.utils.ops import xywhn2xyxy  # normalized → pixel
from ultralytics.utils.ops import xyxy2xywhn  # pixel → normalized
from ultralytics.utils.ops import xywh2ltwh  # xywh → top-left corner, w, h
from ultralytics.utils.ops import xyxy2ltwh  # xyxy → top-left corner, w, h
from ultralytics.utils.ops import ltwh2xywh
from ultralytics.utils.ops import ltwh2xyxy
```

See the docstring for each function or visit the `ultralytics.utils.ops` [reference page](../reference/utils/ops.md) to read more about each function.

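The normalized variants divide or multiply coordinates by the image width and height. A hypothetical round-trip sketch in plain NumPy, assuming pixel boxes in an image of width `w` and height `h`; `xyxy_to_xywhn` and `xywhn_to_xyxy` are illustrative names, and the library functions may accept extra options (such as padding offsets), so consult their docstrings.

```python
import numpy as np


def xyxy_to_xywhn(xyxy, w, h):
    """Pixel (x1, y1, x2, y2) -> normalized (x_center, y_center, width, height)."""
    x1, y1, x2, y2 = np.asarray(xyxy, dtype=np.float64).T
    return np.stack([(x1 + x2) / 2 / w, (y1 + y2) / 2 / h, (x2 - x1) / w, (y2 - y1) / h], axis=-1)


def xywhn_to_xyxy(xywhn, w, h):
    """Normalized (x_center, y_center, width, height) -> pixel (x1, y1, x2, y2)."""
    xc, yc, bw, bh = np.asarray(xywhn, dtype=np.float64).T
    return np.stack([(xc - bw / 2) * w, (yc - bh / 2) * h, (xc + bw / 2) * w, (yc + bh / 2) * h], axis=-1)


box = np.array([[100.0, 200.0, 300.0, 600.0]])  # pixels in a 400 (w) x 800 (h) image
xywhn = xyxy_to_xywhn(box, w=400, h=800)
print(xywhn)  # [[0.5 0.5 0.5 0.5]]
print(xywhn_to_xyxy(xywhn, w=400, h=800))  # [[100. 200. 300. 600.]]
```
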
## Plotting

### Drawing Annotations

Ultralytics includes an `Annotator` class that can be used to annotate any kind of data. It's easiest to use with [object detection bounding boxes](../modes/predict.md#boxes), [pose key points](../modes/predict.md#keypoints), and [oriented bounding boxes](../modes/predict.md#obb).

#### Horizontal Bounding Boxes

```{ .py .annotate }
import cv2 as cv
import numpy as np

from ultralytics.utils.plotting import Annotator, colors

names = {  # (1)!
    0: "person",
    5: "bus",
    11: "stop sign",
}

image = cv.imread("ultralytics/assets/bus.jpg")
ann = Annotator(
    image,
    line_width=None,  # default auto-size
    font_size=None,  # default auto-size
    font="Arial.ttf",  # must be ImageFont compatible
    pil=False,  # True uses PIL, False uses OpenCV
)

xyxy_boxes = np.array(
    [
        [5, 22.878, 231.27, 804.98, 756.83],  # class-idx x1 y1 x2 y2
        [0, 48.552, 398.56, 245.35, 902.71],
        [0, 669.47, 392.19, 809.72, 877.04],
        [0, 221.52, 405.8, 344.98, 857.54],
        [0, 0, 550.53, 63.01, 873.44],
        [11, 0.0584, 254.46, 32.561, 324.87],
    ]
)

for nb, box in enumerate(xyxy_boxes):
    c_idx, *box = box  # split class index from box coordinates
    label = f"{str(nb).zfill(2)}:{names.get(int(c_idx))}"
    ann.box_label(box, label, color=colors(c_idx, bgr=True))

image_with_bboxes = ann.result()
```

1. Names can be used from `model.names` when [working with detection results](../modes/predict.md#working-with-results).

#### Oriented Bounding Boxes (OBB)

```python
import cv2 as cv
import numpy as np

from ultralytics.utils.plotting import Annotator, colors

obb_names = {10: "small vehicle"}
obb_image = cv.imread("datasets/dota8/images/train/P1142__1024__0___824.jpg")
obb_boxes = np.array(
    [
        [0, 635, 560, 919, 719, 1087, 420, 803, 261],  # class-idx x1 y1 x2 y2 x3 y3 x4 y4
        [0, 331, 19, 493, 260, 776, 70, 613, -171],
        [9, 869, 161, 886, 147, 851, 101, 833, 115],
    ]
)
ann = Annotator(
    obb_image,
    line_width=None,  # default auto-size
    font_size=None,  # default auto-size
    font="Arial.ttf",  # must be ImageFont compatible
    pil=False,  # True uses PIL, False uses OpenCV
)
for obb in obb_boxes:
    c_idx, *obb = obb  # split class index from corner coordinates
    obb = np.array(obb).reshape(-1, 4, 2).squeeze()  # 4 (x, y) corner points
    label = f"{obb_names.get(int(c_idx))}"
    ann.box_label(
        obb,
        label,
        color=colors(c_idx, True),
        rotated=True,
    )

image_with_obb = ann.result()
```

See the [`Annotator` Reference Page](../reference/utils/plotting.md#ultralytics.utils.plotting.Annotator) for additional insight.

## Miscellaneous

### Code Profiling

Check how long code takes to run or process, either with a `with` statement or as a decorator.

```python
from ultralytics.utils.ops import Profile

with Profile() as dt:  # pass device=... when timing CUDA operations
    pass  # operation to measure

print(dt)
>>> "Elapsed time is 9.5367431640625e-07 s"
```

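Supporting both `with` and decorator usage is what Python's `contextlib.ContextDecorator` provides. A minimal stand-in timer built that way, assuming only the standard library; the `Timer` class is hypothetical and is not the Ultralytics `Profile` class, though it mimics the printed message shown above.

```python
import time
from contextlib import ContextDecorator


class Timer(ContextDecorator):
    """Accumulate wall-clock time; usable as `with timer:` or `@timer`."""

    def __init__(self):
        self.t = 0.0  # total elapsed seconds across all uses

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.t += time.perf_counter() - self.start
        return False  # do not suppress exceptions

    def __str__(self):
        return f"Elapsed time is {self.t} s"


timer = Timer()

with timer:  # context-manager form
    sum(range(100_000))  # operation to measure


@timer  # decorator form: each call is timed by the same instance
def work():
    return sum(range(100_000))


work()
print(timer)
```
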
### Ultralytics Supported Formats

Need to use the [image or video formats supported](../modes/predict.md#image-and-video-formats) by Ultralytics programmatically? Use these constants.

```python
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS

print(IMG_FORMATS)
>>> ('bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm')
```

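A common use is filtering a file listing down to supported images. A sketch using a hard-coded copy of the tuple printed above (so it runs without Ultralytics installed; the constant's contents may change between versions, so import it in real code) and a hypothetical helper `is_image`:

```python
from pathlib import Path

# Copied from the IMG_FORMATS output above; prefer importing it from ultralytics.data.utils
IMG_FORMATS = ("bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm")


def is_image(path):
    """True if the file extension (case-insensitive) is a supported image format."""
    return Path(path).suffix[1:].lower() in IMG_FORMATS


files = ["bus.jpg", "zidane.JPG", "notes.txt", "clip.mp4", "scan.tiff"]
print([f for f in files if is_image(f)])  # ['bus.jpg', 'zidane.JPG', 'scan.tiff']
```
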
### Make Divisible

Rounds `x` up to the nearest multiple of `y`, i.e. returns the smallest value greater than or equal to `x` that is evenly divisible by `y`.

```python
from ultralytics.utils.ops import make_divisible

make_divisible(7, 3)
>>> 9
make_divisible(7, 2)
>>> 8
```

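The same rounding can be expressed with `math.ceil`. A hypothetical helper `round_up_to_multiple` that reproduces the two examples above; a typical use of this kind of rounding is keeping sizes aligned to a stride such as 32.

```python
import math


def round_up_to_multiple(x, y):
    """Smallest multiple of y that is >= x."""
    return math.ceil(x / y) * y


print(round_up_to_multiple(7, 3))  # 9
print(round_up_to_multiple(7, 2))  # 8
print(round_up_to_multiple(64, 32))  # 64 (already a multiple, unchanged)
```
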