Merge branch 'main' into yt_badges

Francesco Mattioli committed 5 months ago via GitHub
commit e2f92fe795
1. .github/workflows/docs.yml (4 changes)
2. docs/en/guides/parking-management.md (33 changes)
3. docs/en/hub/inference-api.md (4 changes)
4. docs/en/integrations/paddlepaddle.md (11 changes)
5. docs/en/modes/predict.md (39 changes)
6. docs/en/reference/nn/tasks.md (8 changes)
7. docs/en/reference/utils/__init__.md (4 changes)
8. docs/en/tasks/obb.md (55 changes)
9. docs/overrides/main.html (62 changes)
10. docs/overrides/stylesheets/style.css (109 changes)
11. tests/test_python.py (5 changes)
12. ultralytics/__init__.py (2 changes)
13. ultralytics/cfg/__init__.py (1 change)
14. ultralytics/engine/exporter.py (6 changes)
15. ultralytics/engine/model.py (22 changes)
16. ultralytics/engine/results.py (86 changes)
17. ultralytics/engine/validator.py (3 changes)
18. ultralytics/hub/session.py (42 changes)
19. ultralytics/nn/autobackend.py (6 changes)
20. ultralytics/solutions/parking_management.py (192 changes)
21. ultralytics/utils/__init__.py (58 changes)
22. ultralytics/utils/benchmarks.py (2 changes)
23. ultralytics/utils/checks.py (5 changes)
24. ultralytics/utils/plotting.py (5 changes)
25. ultralytics/utils/torch_utils.py (20 changes)

@@ -48,7 +48,7 @@ jobs:
continue-on-error: true
run: ruff check --fix --unsafe-fixes --select D --ignore=D100,D104,D203,D205,D212,D213,D401,D406,D407,D413 .
- name: Update Docs Reference Section and Push Changes
if: github.event_name == 'pull_request_target'
continue-on-error: true
run: |
python docs/build_reference.py
git pull origin ${{ github.head_ref || github.ref }}
@@ -68,7 +68,7 @@ jobs:
python docs/build_docs.py
- name: Commit and Push Docs changes
continue-on-error: true
if: always() && github.event_name == 'pull_request_target'
if: always()
run: |
git pull origin ${{ github.head_ref || github.ref }}
git add --update # only add updated files

@@ -74,9 +74,6 @@ Parking management with [Ultralytics YOLOv8](https://github.com/ultralytics/ultr
from ultralytics import solutions
# Path to json file, that created with above point selection app
polygon_json_path = "bounding_boxes.json"
# Video capture
cap = cv2.VideoCapture("Path/to/video/file.mp4")
assert cap.isOpened(), "Error reading video file"
@@ -86,22 +83,16 @@ Parking management with [Ultralytics YOLOv8](https://github.com/ultralytics/ultr
video_writer = cv2.VideoWriter("parking management.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
# Initialize parking management object
management = solutions.ParkingManagement(model_path="yolov8n.pt")
parking_manager = solutions.ParkingManagement(
model="yolov8n.pt", # path to model file
json_file="bounding_boxes.json", # path to parking annotations file
)
while cap.isOpened():
ret, im0 = cap.read()
if not ret:
break
json_data = management.parking_regions_extraction(polygon_json_path)
results = management.model.track(im0, persist=True, show=False)
if results[0].boxes.id is not None:
boxes = results[0].boxes.xyxy.cpu().tolist()
clss = results[0].boxes.cls.cpu().tolist()
management.process_data(json_data, im0, boxes, clss)
management.display_frames(im0)
im0 = parking_manager.process_data(im0)
video_writer.write(im0)
cap.release()
@@ -111,14 +102,12 @@ Parking management with [Ultralytics YOLOv8](https://github.com/ultralytics/ultr
### Optional Arguments `ParkingManagement`
| Name | Type | Default | Description |
| ------------------------ | ------- | ----------------- | -------------------------------------- |
| `model_path` | `str` | `None` | Path to the YOLOv8 model. |
| `txt_color` | `tuple` | `(0, 0, 0)` | RGB color tuple for text. |
| `bg_color` | `tuple` | `(255, 255, 255)` | RGB color tuple for background. |
| `occupied_region_color` | `tuple` | `(0, 255, 0)` | RGB color tuple for occupied regions. |
| `available_region_color` | `tuple` | `(0, 0, 255)` | RGB color tuple for available regions. |
| `margin` | `int` | `10` | Margin for text display. |
| Name | Type | Default | Description |
| ------------------------ | ------- | ------------- | -------------------------------------------------------------- |
| `model` | `str` | `None` | Path to the YOLOv8 model. |
| `json_file`              | `str`   | `None`        | Path to the JSON file that contains all parking coordinates data. |
| `occupied_region_color` | `tuple` | `(0, 0, 255)` | RGB color for occupied regions. |
| `available_region_color` | `tuple` | `(0, 255, 0)` | RGB color for available regions. |
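Matching the table above, a minimal construction sketch using the documented defaults (the color tuples are passed straight to OpenCV drawing calls):

```python
from ultralytics import solutions

# Construct with the documented arguments; model and json_file are file paths
parking_manager = solutions.ParkingManagement(
    model="yolov8n.pt",  # YOLOv8 model file
    json_file="bounding_boxes.json",  # annotations from the point selection app
    occupied_region_color=(0, 0, 255),  # default occupied color
    available_region_color=(0, 255, 0),  # default available color
)
```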
### Arguments `model.track`

@@ -139,7 +139,7 @@ The [Ultralytics HUB](https://www.ultralytics.com/hub) Inference API returns a J
results = model("image.jpg")
# Print image.jpg results in JSON format
print(results[0].tojson())
print(results[0].to_json())
```
=== "cURL"
@@ -219,7 +219,7 @@ The [Ultralytics HUB](https://www.ultralytics.com/hub) Inference API returns a J
results = model("image.jpg")
# Print image.jpg results in JSON format
print(results[0].tojson())
print(results[0].to_json())
```
=== "cURL"

@@ -8,6 +8,17 @@ keywords: YOLOv8, PaddlePaddle, export models, computer vision, deep learning, m
Bridging the gap between developing and deploying computer vision models in real-world scenarios with varying conditions can be difficult. PaddlePaddle makes this process easier with its focus on flexibility, performance, and its capability for parallel processing in distributed environments. This means you can use your YOLOv8 computer vision models on a wide variety of devices and platforms, from smartphones to cloud-based servers.
<p align="center">
<br>
<iframe loading="lazy" width="720" height="405" src="https://www.youtube.com/embed/c5eFrt2KuzY"
title="YouTube video player" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
allowfullscreen>
</iframe>
<br>
<strong>Watch:</strong> How to Export Ultralytics YOLOv8 Models to PaddlePaddle Format | Key Features of PaddlePaddle Format
</p>
The ability to export to PaddlePaddle model format allows you to optimize your [Ultralytics YOLOv8](https://github.com/ultralytics/ultralytics) models for use within the PaddlePaddle framework. PaddlePaddle is known for facilitating industrial deployments and is a good choice for deploying computer vision applications in real-world settings across various domains.
## Why should you export to PaddlePaddle?

@@ -328,9 +328,10 @@ Below are code examples for using each source type:
results = model(source, stream=True) # generator of Results objects
```
=== "Streams"
=== "Stream"
Use the stream mode to run inference on live video streams using RTSP, RTMP, TCP, or IP address protocols. If a single stream is provided, the model runs inference with a batch size of 1. For multiple streams, a `.streams` text file can be used to perform batched inference, where the batch size is determined by the number of streams provided (e.g., batch-size 8 for 8 streams).
Run inference on remote streaming sources using RTSP, RTMP, TCP and IP address protocols. If multiple streams are provided in a `*.streams` text file then batched inference will run, i.e. 8 streams will run at batch-size 8, otherwise single streams will run at batch-size 1.
```python
from ultralytics import YOLO
@@ -338,15 +339,43 @@ Below are code examples for using each source type:
model = YOLO("yolov8n.pt")
# Single stream with batch-size 1 inference
source = "rtsp://example.com/media.mp4" # RTSP, RTMP, TCP or IP streaming address
source = "rtsp://example.com/media.mp4" # RTSP, RTMP, TCP, or IP streaming address
# Run inference on the source
results = model(source, stream=True) # generator of Results objects
```
For single stream usage, the batch size is set to 1 by default, allowing efficient real-time processing of the video feed.
=== "Multi-Stream"
To handle multiple video streams simultaneously, use a `.streams` text file containing the streaming sources. The model will run batched inference where the batch size equals the number of streams. This setup enables efficient processing of multiple feeds concurrently.
```python
from ultralytics import YOLO
# Load a pretrained YOLOv8n model
model = YOLO("yolov8n.pt")
# Multiple streams with batched inference (i.e. batch-size 8 for 8 streams)
source = "path/to/list.streams" # *.streams text file with one streaming address per row
# Multiple streams with batched inference (e.g., batch-size 8 for 8 streams)
source = "path/to/list.streams" # *.streams text file with one streaming address per line
# Run inference on the source
results = model(source, stream=True) # generator of Results objects
```
Example `.streams` text file:
```txt
rtsp://example.com/media1.mp4
rtsp://example.com/media2.mp4
rtmp://example2.com/live
tcp://192.168.1.100:554
...
```
Each row in the file represents a streaming source, allowing you to monitor and perform inference on several video streams at once.
## Inference Arguments
`model.predict()` accepts multiple arguments that can be passed at inference time to override defaults:
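For instance, a minimal sketch of overriding two common defaults at call time (`conf` and `imgsz` are standard prediction arguments; the source image URL is illustrative):

```python
from ultralytics import YOLO

# Load a pretrained YOLOv8n model
model = YOLO("yolov8n.pt")

# Override the confidence threshold and inference image size for this call only
results = model.predict("https://ultralytics.com/images/bus.jpg", conf=0.5, imgsz=320)
```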

@@ -47,6 +47,14 @@ keywords: Ultralytics, YOLO, nn tasks, DetectionModel, PoseModel, RTDETRDetectio
<br><br><hr><br>
## ::: ultralytics.nn.tasks.SafeClass
<br><br><hr><br>
## ::: ultralytics.nn.tasks.SafeUnpickler
<br><br><hr><br>
## ::: ultralytics.nn.tasks.temporary_modules
<br><br><hr><br>

@@ -39,6 +39,10 @@ keywords: Ultralytics, utils, TQDM, Python, ML, Machine Learning utilities, YOLO
<br><br><hr><br>
## ::: ultralytics.utils.PersistentCacheDict
<br><br><hr><br>
## ::: ultralytics.utils.plt_settings
<br><br><hr><br>

@@ -19,28 +19,17 @@ The output of an oriented object detector is a set of rotated bounding boxes tha
YOLOv8 OBB models use the `-obb` suffix, i.e. `yolov8n-obb.pt` and are pretrained on [DOTAv1](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/DOTAv1.yaml).
<table>
<tr>
<td align="center">
<iframe loading="lazy" width="720" height="405" src="https://www.youtube.com/embed/Z7Z9pHF8wJc"
title="YouTube video player" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
allowfullscreen>
</iframe>
<br>
<strong>Watch:</strong> Object Detection using Ultralytics YOLOv8 Oriented Bounding Boxes (YOLOv8-OBB) <a href="https://www.youtube.com/watch?v=Z7Z9pHF8wJc" alt="YouTube Views"><img src="https://img.shields.io/youtube/views/Z7Z9pHF8wJc" alt="YouTube Views"></a>
</td>
<td align="center">
<iframe loading="lazy" width="720" height="405" src="https://www.youtube.com/embed/uZ7SymQfqKI"
title="YouTube video player" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
allowfullscreen>
</iframe>
<br>
<strong>Watch:</strong> Object Detection with YOLOv8-OBB using Ultralytics HUB
</td>
</tr>
</table>
<p align="center">
<br>
<iframe loading="lazy" width="720" height="405" src="https://www.youtube.com/embed/Z7Z9pHF8wJc"
title="YouTube video player" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
allowfullscreen>
</iframe>
<br>
<strong>Watch:</strong> Object Detection using Ultralytics YOLOv8 Oriented Bounding Boxes (YOLOv8-OBB)
</p>
## Visual Samples
@@ -98,6 +87,17 @@ Train YOLOv8n-obb on the `dota8.yaml` dataset for 100 epochs at image size 640.
yolo obb train data=dota8.yaml model=yolov8n-obb.yaml pretrained=yolov8n-obb.pt epochs=100 imgsz=640
```
<p align="center">
<br>
<iframe loading="lazy" width="720" height="405" src="https://www.youtube.com/embed/uZ7SymQfqKI"
title="YouTube video player" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
allowfullscreen>
</iframe>
<br>
<strong>Watch:</strong> How to Train Ultralytics YOLOv8-OBB (Oriented Bounding Boxes) Models on DOTA Dataset using Ultralytics HUB
</p>
### Dataset format
OBB dataset format can be found in detail in the [Dataset Guide](../datasets/obb/index.md).
@@ -158,6 +158,17 @@ Use a trained YOLOv8n-obb model to run predictions on images.
yolo obb predict model=path/to/best.pt source='https://ultralytics.com/images/bus.jpg' # predict with custom model
```
<p align="center">
<br>
<iframe loading="lazy" width="720" height="405" src="https://www.youtube.com/embed/5XYdm5CYODA"
title="YouTube video player" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
allowfullscreen>
</iframe>
<br>
<strong>Watch:</strong> How to Detect and Track Storage Tanks using Ultralytics YOLOv8-OBB | Oriented Bounding Boxes | DOTA
</p>
See full `predict` mode details in the [Predict](../modes/predict.md) page.
## Export

@@ -1,30 +1,38 @@
<!--Ultralytics YOLO 🚀, AGPL-3.0 license-->
{% extends "base.html" %}
{% block announce %}
<a
href="https://github.com/ultralytics/ultralytics/releases/tag/v8.2.0"
target="_blank"
class="banner-wrapper"
>
<img
src="https://assets-global.website-files.com/646dd1f1a3703e451ba81ecc/66cb78d08408c438e54a6f2f_yolov82_release.avif"
loading="lazy"
alt="Ultralytics YOLOv8.2 Release"
class="banner-content desktop"
/>
<img
src="https://assets-global.website-files.com/646dd1f1a3703e451ba81ecc/66cb7a122db51139d6c0b4a8_yolov82_release_mobile.avif"
loading="lazy"
alt="Ultralytics YOLOv8.2 Release Mobile"
class="banner-content mobile"
/>
<img
src="https://assets-global.website-files.com/646dd1f1a3703e451ba81ecc/66cb779fc2ff285f3efceea1_arrow_effects.avif"
loading="lazy"
alt="Ultralytics YOLOv8.2 Release Arrow"
class="banner-arrow"
/>
</a>
{% extends "base.html" %} {% block announce %}
<div class="banner-wrapper">
<div class="banner-content-wrapper">
<p>YOLO Vision 2024 is here!</p>
<div class="banner-info-wrapper">
<img
src="https://assets-global.website-files.com/646dd1f1a3703e451ba81ecc/66e9a87cfc78245ffa51d6f0_w_yv24.svg"
loading="lazy"
width="20"
height="20"
alt="YOLO Vision 24"
/>
<p>September 27, 2024</p>
</div>
<div class="banner-info-wrapper">
<img
src="https://assets-global.website-files.com/646dd1f1a3703e451ba81ecc/66e9a87cdfbd25e409560ed8_l_yv24.svg"
loading="lazy"
width="20"
height="20"
alt="YOLO Vision 24"
/>
<p>Free hybrid event</p>
</div>
</div>
<div class="banner-button-wrapper">
<div class="banner-button-wrapper large">
<button
onclick="window.open('https://www.ultralytics.com/events/yolovision', '_blank')"
>
Register now
</button>
</div>
</div>
</div>
{% endblock %}

@@ -51,7 +51,7 @@ div.highlight {
/* Banner (same as the one on the Ultralytics website) -------------------------------------------------------------- */
.md-banner {
background-image: url(https://assets-global.website-files.com/646dd1f1a3703e451ba81ecc/6627a0cab2de939ad35939ed_banner_82.webp);
background-image: url(https://assets-global.website-files.com/646dd1f1a3703e451ba81ecc/66e9a211bf6831d112fd6ce3_banner_yv24.avif);
background-size: cover;
background-position: center;
}
@@ -61,44 +61,109 @@ div.highlight {
margin-bottom: 0 !important;
}
.banner-wrapper {
.banner-wrapper,
.banner-wrapper > .banner-content-wrapper,
.banner-wrapper > .banner-content-wrapper > .banner-info-wrapper {
display: flex;
justify-content: center;
align-items: center;
}
.banner-wrapper,
.banner-wrapper > .banner-content-wrapper {
flex-direction: column;
}
.banner-wrapper {
justify-content: space-between;
gap: 16px;
padding: 16px;
}
.banner-wrapper > .banner-content-wrapper,
.banner-wrapper > .banner-content-wrapper > .banner-info-wrapper {
justify-content: center;
}
.banner-wrapper > .banner-content-wrapper {
gap: 8px;
}
.banner-wrapper > .banner-content-wrapper > .banner-info-wrapper {
gap: 4px;
}
height: 64px;
.banner-wrapper > .banner-content-wrapper > p,
.banner-wrapper > .banner-content-wrapper > .banner-info-wrapper > p {
margin: 0;
}
.banner-wrapper > .banner-content-wrapper > p {
font-size: 20px;
font-weight: 500;
}
.banner-wrapper > .banner-content-wrapper > .banner-info-wrapper > p,
.banner-wrapper > .banner-button-wrapper > .banner-button-wrapper > button {
font-size: 14px;
}
.banner-wrapper > .banner-content-wrapper > .banner-info-wrapper > p {
color: #f3f3f3;
}
overflow: hidden;
.banner-wrapper > .banner-button-wrapper,
.banner-wrapper > .banner-button-wrapper > .banner-button-wrapper,
.banner-wrapper > .banner-button-wrapper > .banner-button-wrapper > button {
border-radius: 100px;
}
.banner-content {
max-height: 64px;
.banner-wrapper > .banner-button-wrapper,
.banner-wrapper > .banner-button-wrapper > .banner-button-wrapper {
padding: 2px;
background-color: rgba(222, 255, 56, 0.2);
}
.banner-content.desktop {
display: none;
.banner-wrapper > .banner-button-wrapper > .banner-button-wrapper.large {
padding: 4px;
}
.banner-arrow {
display: none;
max-height: 80px;
margin-left: -16px;
transition: transform ease-in-out 0.5s;
.banner-wrapper > .banner-button-wrapper > .banner-button-wrapper > button {
cursor: pointer;
min-width: 132px;
padding: 10px;
font-weight: 500;
color: #111f68;
background-color: rgb(222, 255, 56);
}
.banner-wrapper:hover > .banner-arrow {
transform: translateX(8px);
.banner-wrapper
> .banner-button-wrapper
> .banner-button-wrapper
> button:hover {
background-color: rgba(222, 255, 56, 0.85);
}
@media screen and (min-width: 768px) {
.banner-content.mobile {
display: none;
.banner-wrapper,
.banner-wrapper > .banner-content-wrapper {
flex-direction: row;
}
.banner-content.desktop {
display: revert;
.banner-wrapper {
gap: 32px;
padding: 12px;
}
.banner-arrow {
display: revert;
.banner-wrapper > .banner-content-wrapper {
gap: 24px;
margin: 0 auto;
}
}
/* Banner (same as the one on the Ultralytics website) -------------------------------------------------------------- */

@@ -269,7 +269,10 @@ def test_results(model):
r = r.to(device="cpu", dtype=torch.float32)
r.save_txt(txt_file=TMP / "runs/tests/label.txt", save_conf=True)
r.save_crop(save_dir=TMP / "runs/tests/crops/")
r.tojson(normalize=True)
r.to_json(normalize=True)
r.to_df(decimals=3)
r.to_csv()
r.to_xml()
r.plot(pil=True)
r.plot(conf=True, boxes=True)
print(r, len(r), r.path) # print after methods
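For reference, the new `Results` export helpers exercised in this test can be called directly on predictions; a minimal sketch assuming a pretrained `yolov8n.pt`:

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
results = model("https://ultralytics.com/images/bus.jpg")

r = results[0]
print(r.to_json())  # JSON string (tojson() is now a deprecated alias)
print(r.to_csv())  # CSV string via pandas.DataFrame.to_csv()
print(r.to_df(decimals=3))  # pandas DataFrame of detections
```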

@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.2.95"
__version__ = "8.2.98"
import os

@@ -712,6 +712,7 @@ def entrypoint(debug=""):
"cfg": lambda: yaml_print(DEFAULT_CFG_PATH),
"hub": lambda: handle_yolo_hub(args[1:]),
"login": lambda: handle_yolo_hub(args),
"logout": lambda: handle_yolo_hub(args),
"copy-cfg": copy_default_cfg,
"explorer": lambda: handle_explorer(args[1:]),
"streamlit-predict": lambda: handle_streamlit_inference(),

@@ -95,9 +95,7 @@ from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_d
def export_formats():
"""YOLOv8 export formats."""
import pandas # scope for faster 'import ultralytics'
"""Ultralytics YOLO export formats."""
x = [
["PyTorch", "-", ".pt", True, True],
["TorchScript", "torchscript", ".torchscript", True, True],
@@ -113,7 +111,7 @@ def export_formats():
["PaddlePaddle", "paddle", "_paddle_model", True, True],
["NCNN", "ncnn", "_ncnn_model", True, True],
]
return pandas.DataFrame(x, columns=["Format", "Argument", "Suffix", "CPU", "GPU"])
return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x)))
def gd_outputs(gd):
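With this change `export_formats()` returns a plain dict mapping column names to tuples rather than a `pandas.DataFrame`, so callers index by key or re-zip to iterate rows, as the updated `benchmarks.py` loop further down does. A minimal access sketch:

```python
from ultralytics.engine.exporter import export_formats

fmts = export_formats()
print(fmts["Argument"])  # e.g. ('-', 'torchscript', 'onnx', ...)

# Iterate row-wise, mirroring the updated loop in ultralytics/utils/benchmarks.py
for name, argument, suffix, cpu, gpu in zip(*fmts.values()):
    print(name, argument, suffix, cpu, gpu)
```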

@@ -206,33 +206,21 @@ class Model(nn.Module):
Check if the provided model is an Ultralytics HUB model.
This static method determines whether the given model string represents a valid Ultralytics HUB model
identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination,
or a standalone model ID.
identifier.
Args:
model (str): The model identifier to check. This can be a URL, an API key and model ID
combination, or a standalone model ID.
model (str): The model string to check.
Returns:
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
Examples:
>>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model")
>>> Model.is_hub_model("https://hub.ultralytics.com/models/MODEL")
True
>>> Model.is_hub_model("api_key_example_model_id")
True
>>> Model.is_hub_model("example_model_id")
True
>>> Model.is_hub_model("not_a_hub_model.pt")
>>> Model.is_hub_model("yolov8n.pt")
False
"""
return any(
(
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
)
)
return model.startswith(f"{HUB_WEB_ROOT}/models/")
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
"""

@@ -14,6 +14,7 @@ import torch
from ultralytics.data.augment import LetterBox
from ultralytics.utils import LOGGER, SimpleClass, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.plotting import Annotator, colors, save_one_box
from ultralytics.utils.torch_utils import smart_inference_mode
@@ -818,7 +819,90 @@ class Results(SimpleClass):
return results
def to_df(self, normalize=False, decimals=5):
"""
Converts detection results to a Pandas DataFrame.
This method converts the detection results into a Pandas DataFrame. It includes information
about detected objects such as bounding boxes, class names, confidence scores, and optionally
segmentation masks and keypoints.
Args:
normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions.
If True, coordinates will be returned as float values between 0 and 1. Defaults to False.
decimals (int): Number of decimal places to round the output values to. Defaults to 5.
Returns:
(DataFrame): A Pandas DataFrame containing all the information in results in an organized way.
Examples:
>>> results = model("path/to/image.jpg")
>>> df_result = results[0].to_df()
>>> print(df_result)
"""
import pandas as pd
return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals))
def to_csv(self, normalize=False, decimals=5, *args, **kwargs):
"""
Converts detection results to CSV format.
This method serializes the detection results into a CSV format. It includes information
about detected objects such as bounding boxes, class names, confidence scores, and optionally
segmentation masks and keypoints.
Args:
normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions.
If True, coordinates will be returned as float values between 0 and 1. Defaults to False.
decimals (int): Number of decimal places to round the output values to. Defaults to 5.
*args (Any): Variable length argument list to be passed to pandas.DataFrame.to_csv().
**kwargs (Any): Arbitrary keyword arguments to be passed to pandas.DataFrame.to_csv().
Returns:
(str): CSV containing all the information in results in an organized way.
Examples:
>>> results = model("path/to/image.jpg")
>>> csv_result = results[0].to_csv()
>>> print(csv_result)
"""
return self.to_df(normalize=normalize, decimals=decimals).to_csv(*args, **kwargs)
def to_xml(self, normalize=False, decimals=5, *args, **kwargs):
"""
Converts detection results to XML format.
This method serializes the detection results into XML format. It includes information
about detected objects such as bounding boxes, class names, confidence scores, and optionally
segmentation masks and keypoints.
Args:
normalize (bool): Whether to normalize the bounding box coordinates by the image dimensions.
If True, coordinates will be returned as float values between 0 and 1. Defaults to False.
decimals (int): Number of decimal places to round the output values to. Defaults to 5.
*args (Any): Variable length argument list to be passed to pandas.DataFrame.to_xml().
**kwargs (Any): Arbitrary keyword arguments to be passed to pandas.DataFrame.to_xml().
Returns:
(str): An XML string containing all the information in results in an organized way.
Examples:
>>> results = model("path/to/image.jpg")
>>> xml_result = results[0].to_xml()
>>> print(xml_result)
"""
check_requirements("lxml")
df = self.to_df(normalize=normalize, decimals=decimals)
return '<?xml version="1.0" encoding="utf-8"?>\n<root></root>' if df.empty else df.to_xml(*args, **kwargs)
def tojson(self, normalize=False, decimals=5):
"""Deprecated version of to_json()."""
LOGGER.warning("WARNING ⚠ 'result.tojson()' is deprecated, replace with 'result.to_json()'.")
return self.to_json(normalize, decimals)
def to_json(self, normalize=False, decimals=5):
"""
Converts detection results to JSON format.
@@ -836,7 +920,7 @@ class Results(SimpleClass):
Examples:
>>> results = model("path/to/image.jpg")
>>> json_result = results[0].tojson()
>>> json_result = results[0].to_json()
>>> print(json_result)
Notes:

@@ -110,7 +110,8 @@ class BaseValidator:
if self.training:
self.device = trainer.device
self.data = trainer.data
self.args.half = self.device.type != "cpu" # force FP16 val during training
# force FP16 val during training
self.args.half = self.device.type != "cpu" and trainer.amp
model = trainer.ema.ema or trainer.model
model = model.half() if self.args.half else model.float()
# self.model = model

@@ -5,6 +5,7 @@
import time
from http import HTTPStatus
from pathlib import Path
from urllib.parse import parse_qs, urlparse
import requests
@@ -77,7 +78,6 @@ class HUBTrainingSession:
if not session.client.authenticated:
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
LOGGER.warning(f"{PREFIX}WARNING ⚠ Login to Ultralytics HUB with 'yolo hub login API_KEY'.")
exit()
return None
if args and not identifier.startswith(f"{HUB_WEB_ROOT}/models/"): # not a HUB model URL
session.create_model(args)
@@ -96,7 +96,8 @@ class HUBTrainingSession:
self.model_url = f"{HUB_WEB_ROOT}/models/{self.model.id}"
if self.model.is_trained():
print(emojis(f"Loading trained HUB model {self.model_url} 🚀"))
self.model_file = self.model.get_weights_url("best")
url = self.model.get_weights_url("best") # download URL with auth
self.model_file = checks.check_file(url, download_dir=Path(SETTINGS["weights_dir"]) / "hub" / self.model.id)
return
# Set training args and start heartbeats for HUB to monitor agent
@@ -146,9 +147,8 @@
Parses the given identifier to determine the type of identifier and extract relevant components.
The method supports different identifier formats:
- A HUB URL, which starts with HUB_WEB_ROOT followed by '/models/'
- An identifier containing an API key and a model ID separated by an underscore
- An identifier that is solely a model ID of a fixed length
- A HUB model URL https://hub.ultralytics.com/models/MODEL
- A HUB model URL with API Key https://hub.ultralytics.com/models/MODEL?api_key=APIKEY
- A local filename that ends with '.pt' or '.yaml'
Args:
@@ -160,32 +160,16 @@
Raises:
HUBModelError: If the identifier format is not recognized.
"""
# Initialize variables
api_key, model_id, filename = None, None, None
# Check if identifier is a HUB URL
if identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
# Extract the model_id after the HUB_WEB_ROOT URL
model_id = identifier.split(f"{HUB_WEB_ROOT}/models/")[-1]
if Path(identifier).suffix in {".pt", ".yaml"}:
filename = identifier
elif identifier.startswith(f"{HUB_WEB_ROOT}/models/"):
parsed_url = urlparse(identifier)
model_id = Path(parsed_url.path).stem # handle possible final backslash robustly
query_params = parse_qs(parsed_url.query) # dictionary, i.e. {"api_key": ["API_KEY_HERE"]}
api_key = query_params.get("api_key", [None])[0]
else:
# Split the identifier based on underscores only if it's not a HUB URL
parts = identifier.split("_")
# Check if identifier is in the format of API key and model ID
if len(parts) == 2 and len(parts[0]) == 42 and len(parts[1]) == 20:
api_key, model_id = parts
# Check if identifier is a single model ID
elif len(parts) == 1 and len(parts[0]) == 20:
model_id = parts[0]
# Check if identifier is a local filename
elif identifier.endswith(".pt") or identifier.endswith(".yaml"):
filename = identifier
else:
raise HUBModelError(
f"model='{identifier}' could not be parsed. Check format is correct. "
f"Supported formats are Ultralytics HUB URL, apiKey_modelId, modelId, local pt or yaml file."
)
raise HUBModelError(f"model='{identifier}' invalid, correct format is {HUB_WEB_ROOT}/models/MODEL_ID")
return api_key, model_id, filename
def _set_train_args(self):
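The identifier parsing above uses only the standard library; a standalone sketch of how a HUB URL with an `api_key` query parameter is decomposed:

```python
from pathlib import Path
from urllib.parse import parse_qs, urlparse

identifier = "https://hub.ultralytics.com/models/MODEL?api_key=APIKEY"
parsed = urlparse(identifier)
model_id = Path(parsed.path).stem  # 'MODEL', robust to a trailing slash
api_key = parse_qs(parsed.query).get("api_key", [None])[0]  # 'APIKEY' or None
print(model_id, api_key)
```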

@@ -398,8 +398,8 @@ class AutoBackend(nn.Module):
from ultralytics.engine.exporter import export_formats
raise TypeError(
f"model='{w}' is not a supported model format. "
f"See https://docs.ultralytics.com/modes/predict for help.\n\n{export_formats()}"
f"model='{w}' is not a supported model format. Ultralytics supports: {export_formats()['Format']}\n"
f"See https://docs.ultralytics.com/modes/predict for help."
)
# Load external metadata YAML
@@ -653,7 +653,7 @@
"""
from ultralytics.engine.exporter import export_formats
sf = list(export_formats().Suffix) # export suffixes
sf = export_formats()["Suffix"] # export suffixes
if not is_url(p) and not isinstance(p, str):
check_suffix(p, sf) # checks
name = Path(p).name

@@ -42,10 +42,10 @@ class ParkingPtsSelection:
self.image_path = None
self.image = None
self.canvas_image = None
self.bounding_boxes = []
self.rg_data = [] # region coordinates
self.current_box = []
self.img_width = 0
self.img_height = 0
self.imgw = 0 # image width
self.imgh = 0 # image height
# Constants
self.canvas_max_width = 1280
@@ -64,17 +64,17 @@
return
self.image = Image.open(self.image_path)
self.img_width, self.img_height = self.image.size
self.imgw, self.imgh = self.image.size
# Calculate the aspect ratio and resize image
aspect_ratio = self.img_width / self.img_height
aspect_ratio = self.imgw / self.imgh
if aspect_ratio > 1:
# Landscape orientation
canvas_width = min(self.canvas_max_width, self.img_width)
canvas_width = min(self.canvas_max_width, self.imgw)
canvas_height = int(canvas_width / aspect_ratio)
else:
# Portrait orientation
canvas_height = min(self.canvas_max_height, self.img_height)
canvas_height = min(self.canvas_max_height, self.imgh)
canvas_width = int(canvas_height * aspect_ratio)
# Check if canvas is already initialized
@@ -90,46 +90,34 @@ class ParkingPtsSelection:
self.canvas.bind("<Button-1>", self.on_canvas_click)
# Reset bounding boxes and current box
self.bounding_boxes = []
self.rg_data = []
self.current_box = []
def on_canvas_click(self, event):
"""Handle mouse clicks on canvas to create points for bounding boxes."""
self.current_box.append((event.x, event.y))
x0, y0 = event.x - 3, event.y - 3
x1, y1 = event.x + 3, event.y + 3
self.canvas.create_oval(x0, y0, x1, y1, fill="red")
self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill="red")
if len(self.current_box) == 4:
self.bounding_boxes.append(self.current_box)
self.draw_bounding_box(self.current_box)
self.rg_data.append(self.current_box)
[
self.canvas.create_line(self.current_box[i], self.current_box[(i + 1) % 4], fill="blue", width=2)
for i in range(4)
]
self.current_box = []
def draw_bounding_box(self, box):
"""
Draw bounding box on canvas.
Args:
box (list): Bounding box data
"""
for i in range(4):
x1, y1 = box[i]
x2, y2 = box[(i + 1) % 4]
self.canvas.create_line(x1, y1, x2, y2, fill="blue", width=2)
def remove_last_bounding_box(self):
"""Remove the last drawn bounding box from canvas."""
from tkinter import messagebox # scope for multi-environment compatibility
if self.bounding_boxes:
self.bounding_boxes.pop() # Remove the last bounding box
if self.rg_data:
self.rg_data.pop() # Remove the last bounding box
self.canvas.delete("all") # Clear the canvas
self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image) # Redraw the image
# Redraw all bounding boxes
for box in self.bounding_boxes:
self.draw_bounding_box(box)
for box in self.rg_data:
[self.canvas.create_line(box[i], box[(i + 1) % 4], fill="blue", width=2) for i in range(4)]
messagebox.showinfo("Success", "Last bounding box removed.")
else:
messagebox.showwarning("Warning", "No bounding boxes to remove.")
@@ -138,19 +126,19 @@ class ParkingPtsSelection:
"""Saves rescaled bounding boxes to 'bounding_boxes.json' based on image-to-canvas size ratio."""
from tkinter import messagebox # scope for multi-environment compatibility
canvas_width, canvas_height = self.canvas.winfo_width(), self.canvas.winfo_height()
width_scaling_factor = self.img_width / canvas_width
height_scaling_factor = self.img_height / canvas_height
bounding_boxes_data = []
for box in self.bounding_boxes:
rescaled_box = []
rg_data = [] # regions data
for box in self.rg_data:
rs_box = [] # rescaled box list
for x, y in box:
rescaled_x = int(x * width_scaling_factor)
rescaled_y = int(y * height_scaling_factor)
rescaled_box.append((rescaled_x, rescaled_y))
bounding_boxes_data.append({"points": rescaled_box})
rs_box.append(
(
int(x * self.imgw / self.canvas.winfo_width()), # width scaling
int(y * self.imgh / self.canvas.winfo_height()),
)
) # height scaling
rg_data.append({"points": rs_box})
with open("bounding_boxes.json", "w") as f:
json.dump(bounding_boxes_data, f, indent=4)
json.dump(rg_data, f, indent=4)
messagebox.showinfo("Success", "Bounding boxes saved to bounding_boxes.json")
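The resulting `bounding_boxes.json` is a list of regions, each holding four rescaled corner points; an illustrative sketch of its structure (coordinates are made up):

```python
import json

# Illustrative structure of the bounding_boxes.json written above:
# one dict per region, each with four (x, y) corner points
regions = [
    {"points": [[100, 200], [300, 200], [300, 400], [100, 400]]},
    {"points": [[350, 210], [540, 210], [540, 410], [350, 410]]},
]
print(json.dumps(regions, indent=4))
```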
@@ -160,102 +148,85 @@ class ParkingManagement:
def __init__(
self,
model_path,
txt_color=(0, 0, 0),
bg_color=(255, 255, 255),
occupied_region_color=(0, 255, 0),
available_region_color=(0, 0, 255),
margin=10,
model, # Ultralytics YOLO model file path
json_file, # Parking management annotation file created from Parking Annotator
occupied_region_color=(0, 0, 255), # occupied region color
available_region_color=(0, 255, 0), # available region color
):
"""
Initializes the parking management system with a YOLOv8 model and visualization settings.
Args:
model_path (str): Path to the YOLOv8 model.
txt_color (tuple): RGB color tuple for text.
bg_color (tuple): RGB color tuple for background.
model (str): Path to the YOLOv8 model.
json_file (str): Path to the JSON file that contains all parking slot points data.
occupied_region_color (tuple): RGB color tuple for occupied regions.
available_region_color (tuple): RGB color tuple for available regions.
margin (int): Margin for text display.
"""
# Model path and initialization
self.model_path = model_path
self.model = self.load_model()
# Labels dictionary
self.labels_dict = {"Occupancy": 0, "Available": 0}
# Visualization details
self.margin = margin
self.bg_color = bg_color
self.txt_color = txt_color
self.occupied_region_color = occupied_region_color
self.available_region_color = available_region_color
self.window_name = "Ultralytics YOLOv8 Parking Management System"
# Check if environment supports imshow
self.env_check = check_imshow(warn=True)
def load_model(self):
"""Load the Ultralytics YOLO model for inference and analytics."""
# Model initialization
from ultralytics import YOLO
return YOLO(self.model_path)
@staticmethod
def parking_regions_extraction(json_file):
"""
Extract parking regions from json file.
self.model = YOLO(model)
Args:
json_file (str): file that have all parking slot points
"""
# Load JSON data
with open(json_file) as f:
return json.load(f)
self.json_data = json.load(f)
def process_data(self, json_data, im0, boxes, clss):
self.pr_info = {"Occupancy": 0, "Available": 0} # dictionary for parking information
self.occ = occupied_region_color
self.arc = available_region_color
self.env_check = check_imshow(warn=True) # check if environment supports imshow
def process_data(self, im0):
"""
Process the model data for parking lot management.
Args:
json_data (str): json data for parking lot management
im0 (ndarray): inference image
boxes (list): bounding boxes data
clss (list): bounding boxes classes list
Returns:
filled_slots (int): total slots that are filled in parking lot
empty_slots (int): total slots that are available in parking lot
"""
annotator = Annotator(im0)
empty_slots, filled_slots = len(json_data), 0
for region in json_data:
points_array = np.array(region["points"], dtype=np.int32).reshape((-1, 1, 2))
region_occupied = False
results = self.model.track(im0, persist=True, show=False) # object tracking
for box, cls in zip(boxes, clss):
x_center = int((box[0] + box[2]) / 2)
y_center = int((box[1] + box[3]) / 2)
text = f"{self.model.names[int(cls)]}"
es, fs = len(self.json_data), 0 # empty slots, filled slots
annotator = Annotator(im0) # init annotator
# extract tracks data
if results[0].boxes.id is None:
self.display_frames(im0)
return im0
boxes = results[0].boxes.xyxy.cpu().tolist()
clss = results[0].boxes.cls.cpu().tolist()
for region in self.json_data:
# Convert points to a NumPy array with the correct dtype and reshape properly
pts_array = np.array(region["points"], dtype=np.int32).reshape((-1, 1, 2))
rg_occupied = False # occupied region initialization
for box, cls in zip(boxes, clss):
xc = int((box[0] + box[2]) / 2)
yc = int((box[1] + box[3]) / 2)
annotator.display_objects_labels(
im0, text, self.txt_color, self.bg_color, x_center, y_center, self.margin
im0, self.model.names[int(cls)], (104, 31, 17), (255, 255, 255), xc, yc, 10
)
dist = cv2.pointPolygonTest(points_array, (x_center, y_center), False)
dist = cv2.pointPolygonTest(pts_array, (xc, yc), False)
if dist >= 0:
region_occupied = True
rg_occupied = True
break
if rg_occupied:
fs += 1
es -= 1
# Plotting regions
color = self.occ if rg_occupied else self.arc
cv2.polylines(im0, [pts_array], isClosed=True, color=color, thickness=2)
color = self.occupied_region_color if region_occupied else self.available_region_color
cv2.polylines(im0, [points_array], isClosed=True, color=color, thickness=2)
if region_occupied:
filled_slots += 1
empty_slots -= 1
self.pr_info["Occupancy"] = fs
self.pr_info["Available"] = es
self.labels_dict["Occupancy"] = filled_slots
self.labels_dict["Available"] = empty_slots
annotator.display_analytics(im0, self.pr_info, (104, 31, 17), (255, 255, 255), 10)
annotator.display_analytics(im0, self.labels_dict, self.txt_color, self.bg_color, self.margin)
self.display_frames(im0)
return im0
def display_frames(self, im0):
"""
@@ -265,8 +236,7 @@ class ParkingManagement:
im0 (ndarray): inference image
"""
if self.env_check:
cv2.namedWindow(self.window_name)
cv2.imshow(self.window_name, im0)
cv2.imshow("Ultralytics Parking Manager", im0)
# Break Window
if cv2.waitKey(1) & 0xFF == ord("q"):
return

@@ -3,6 +3,7 @@
import contextlib
import importlib.metadata
import inspect
import json
import logging.config
import os
import platform
@@ -14,6 +15,7 @@ import time
import urllib
import uuid
from pathlib import Path
from threading import Lock
from types import SimpleNamespace
from typing import Union
@@ -1136,6 +1138,61 @@ class SettingsManager(dict):
self.save()
class PersistentCacheDict(dict):
"""A thread-safe dictionary that persists data to a JSON file for caching purposes."""
def __init__(self, file_path=USER_CONFIG_DIR / "persistent_cache.json"):
"""Initializes a thread-safe persistent cache dictionary with a specified file path for storage."""
super().__init__()
self.file_path = Path(file_path)
self.lock = Lock()
self._load()
def _load(self):
"""Load the persistent cache from a JSON file into the dictionary, handling errors gracefully."""
try:
if self.file_path.exists():
with open(self.file_path) as f:
self.update(json.load(f))
except json.JSONDecodeError:
print(f"Error decoding JSON from {self.file_path}. Starting with an empty cache.")
except Exception as e:
print(f"Error reading from {self.file_path}: {e}")
def _save(self):
"""Save the current state of the cache dictionary to a JSON file, ensuring thread safety."""
try:
self.file_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.file_path, "w") as f:
json.dump(dict(self), f, indent=2)
except Exception as e:
print(f"Error writing to {self.file_path}: {e}")
def __setitem__(self, key, value):
"""Store a key-value pair in the cache and persist the updated cache to disk."""
with self.lock:
super().__setitem__(key, value)
self._save()
def __delitem__(self, key):
"""Remove an item from the PersistentCacheDict and update the persistent storage."""
with self.lock:
super().__delitem__(key)
self._save()
def update(self, *args, **kwargs):
"""Update the dictionary with key-value pairs from other mappings or iterables, ensuring thread safety."""
with self.lock:
super().update(*args, **kwargs)
self._save()
def clear(self):
"""Clears all entries from the persistent cache dictionary, ensuring thread safety."""
with self.lock:
super().clear()
self._save()
def deprecation_warn(arg, new_arg):
"""Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
LOGGER.warning(
@@ -1171,6 +1228,7 @@ def vscode_msg(ext="ultralytics.ultralytics-snippets") -> str:
# Check first-install steps
PREFIX = colorstr("Ultralytics: ")
SETTINGS = SettingsManager() # initialize settings
PERSISTENT_CACHE = PersistentCacheDict() # initialize persistent cache
DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory
WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory
RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory
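The new `PERSISTENT_CACHE` behaves like an ordinary dict whose mutations are mirrored to `persistent_cache.json` in the user config directory; a minimal usage sketch (the key name is illustrative):

```python
from ultralytics.utils import PERSISTENT_CACHE

PERSISTENT_CACHE["example_key"] = "example_value"  # written through to disk
print(PERSISTENT_CACHE.get("example_key"))  # available across Python sessions
del PERSISTENT_CACHE["example_key"]  # deletion is persisted too
```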

@@ -85,7 +85,7 @@ def benchmark(
y = []
t0 = time.time()
for i, (name, format, suffix, cpu, gpu) in export_formats().iterrows(): # index, (name, format, suffix, CPU, GPU)
for i, (name, format, suffix, cpu, gpu) in enumerate(zip(*export_formats().values())):
emoji, filename = "", None # export defaults
try:
# Checks

@@ -656,9 +656,10 @@ def check_amp(model):
def amp_allclose(m, im):
"""All close FP32 vs AMP results."""
a = m(im, device=device, verbose=False)[0].boxes.data # FP32 inference
batch = [im] * 8
a = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data # FP32 inference
with autocast(enabled=True):
b = m(im, device=device, verbose=False)[0].boxes.data # AMP inference
b = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data # AMP inference
del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance

@@ -419,7 +419,7 @@ class Annotator:
# Convert im back to PIL and update draw
self.fromarray(self.im)
def kpts(self, kpts, shape=(640, 640), radius=5, kpt_line=True, conf_thres=0.25, kpt_color=None):
def kpts(self, kpts, shape=(640, 640), radius=None, kpt_line=True, conf_thres=0.25, kpt_color=None):
"""
Plot keypoints on the image.
@@ -436,6 +436,7 @@
- Modifies self.im in-place.
- If self.pil is True, converts image to numpy array and back to PIL.
"""
radius = radius if radius is not None else self.lw
if self.pil:
# Convert to numpy first
self.im = np.asarray(self.im).copy()
@@ -471,7 +472,7 @@
pos1,
pos2,
kpt_color or self.limb_color[i].tolist(),
thickness=2,
thickness=int(np.ceil(self.lw / 2)),
lineType=cv2.LINE_AA,
)
if self.pil:

@@ -110,15 +110,17 @@ def autocast(enabled: bool, device: str = "cuda"):
def get_cpu_info():
"""Return a string with system CPU information, i.e. 'Apple M2'."""
with contextlib.suppress(Exception):
import cpuinfo # pip install py-cpuinfo
from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference (not all keys always available)
info = cpuinfo.get_cpu_info() # info dict
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
return string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
if "cpu_info" not in PERSISTENT_CACHE:
with contextlib.suppress(Exception):
import cpuinfo # pip install py-cpuinfo
return "unknown"
k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference
info = cpuinfo.get_cpu_info() # info dict
string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown")
PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "")
return PERSISTENT_CACHE.get("cpu_info", "unknown")
def select_device(device="", batch=0, newline=False, verbose=True):
@@ -247,7 +249,7 @@ def fuse_conv_and_bn(conv, bn):
)
# Prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_conv = conv.weight.view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
@@ -278,7 +280,7 @@ def fuse_deconv_and_bn(deconv, bn):
)
# Prepare filters
w_deconv = deconv.weight.clone().view(deconv.out_channels, -1)
w_deconv = deconv.weight.view(deconv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape))
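For reference, both fusions implement the standard BatchNorm folding identity, with BN scale γ, shift β, running statistics μ, σ², and stability term ε, which is exactly what the `torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))` term above computes:

$$W_{\text{fused}} = \operatorname{diag}\!\left(\frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\right) W, \qquad b_{\text{fused}} = \beta + \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}\,(b - \mu)$$

Dropping `.clone()` here avoids an extra copy and is safe because the viewed weight is only read by `torch.mm`, never mutated in place.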
