Merge branch 'main' into benchmark-format-args

Ultralytics Assistant committed fe9c04c708
Changed files:

1. README.md (44 changes)
2. README.zh-CN.md (44 changes)
3. docs/en/modes/predict.md (18 changes)
4. docs/en/modes/track.md (114 changes)
5. docs/en/yolov5/quickstart_tutorial.md (4 changes)
6. docs/mkdocs_github_authors.yaml (2 changes)
7. tests/__init__.py (5 changes)
8. tests/test_python.py (42 changes)
9. ultralytics/__init__.py (2 changes)
10. ultralytics/engine/trainer.py (11 changes)
11. ultralytics/nn/tasks.py (4 changes)
12. ultralytics/utils/torch_utils.py (13 changes)

README.md

@@ -87,14 +87,25 @@ YOLOv8 may also be used directly in a Python environment, and accepts the same [
 from ultralytics import YOLO

 # Load a model
-model = YOLO("yolov8n.yaml")  # build a new model from scratch
-model = YOLO("yolov8n.pt")  # load a pretrained model (recommended for training)
-
-# Use the model
-model.train(data="coco8.yaml", epochs=3)  # train the model
-metrics = model.val()  # evaluate model performance on the validation set
-results = model("https://ultralytics.com/images/bus.jpg")  # predict on an image
-path = model.export(format="onnx")  # export the model to ONNX format
+model = YOLO("yolov8n.pt")
+
+# Train the model
+train_results = model.train(
+    data="coco8.yaml",  # path to dataset YAML
+    epochs=100,  # number of training epochs
+    imgsz=640,  # training image size
+    device="cpu",  # device to run on, i.e. device=0 or device=0,1,2,3 or device=cpu
+)
+
+# Evaluate model performance on the validation set
+metrics = model.val()
+
+# Perform object detection on an image
+results = model("path/to/image.jpg")
+results[0].show()
+
+# Export the model to ONNX format
+path = model.export(format="onnx")  # return path to exported model
 ```

 See YOLOv8 [Python Docs](https://docs.ultralytics.com/usage/python/) for more examples.
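The exported ONNX file plugs straight back into the same `YOLO` API. A minimal sketch of the round trip, assuming default export settings and the `path` returned above (e.g. "yolov8n.onnx"):

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")
path = model.export(format="onnx")  # e.g. "yolov8n.onnx"

# Reload the exported artifact; ONNX weights run through the same inference API
onnx_model = YOLO(path)
results = onnx_model("https://ultralytics.com/images/bus.jpg")
results[0].show()
```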
@@ -139,23 +150,6 @@ See [Detection Docs](https://docs.ultralytics.com/tasks/detect/) for usage examp
 </details>

-<details><summary>Detection (Open Image V7)</summary>
-
-See [Detection Docs](https://docs.ultralytics.com/tasks/detect/) for usage examples with these models trained on [Open Image V7](https://docs.ultralytics.com/datasets/detect/open-images-v7/), which include 600 pre-trained classes.
-
-| Model                                                                                     | size<br><sup>(pixels) | mAP<sup>val<br>50-95 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) |
-| ----------------------------------------------------------------------------------------- | --------------------- | -------------------- | ------------------------------ | ----------------------------------- | ------------------ | ----------------- |
-| [YOLOv8n](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-oiv7.pt) | 640                   | 18.4                 | 142.4                          | 1.21                                | 3.5                | 10.5              |
-| [YOLOv8s](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s-oiv7.pt) | 640                   | 27.7                 | 183.1                          | 1.40                                | 11.4               | 29.7              |
-| [YOLOv8m](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8m-oiv7.pt) | 640                   | 33.6                 | 408.5                          | 2.26                                | 26.2               | 80.6              |
-| [YOLOv8l](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8l-oiv7.pt) | 640                   | 34.9                 | 596.9                          | 2.43                                | 44.1               | 167.4             |
-| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8x-oiv7.pt) | 640                   | 36.3                 | 860.6                          | 3.56                                | 68.7               | 260.6             |
-
-- **mAP<sup>val</sup>** values are for single-model single-scale on [Open Image V7](https://docs.ultralytics.com/datasets/detect/open-images-v7/) dataset. <br>Reproduce by `yolo val detect data=open-images-v7.yaml device=0`
-- **Speed** averaged over Open Image V7 val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. <br>Reproduce by `yolo val detect data=open-images-v7.yaml batch=1 device=0|cpu`
-
-</details>
-
 <details><summary>Segmentation (COCO)</summary>

 See [Segmentation Docs](https://docs.ultralytics.com/tasks/segment/) for usage examples with these models trained on [COCO-Seg](https://docs.ultralytics.com/datasets/segment/coco/), which include 80 pre-trained classes.

README.zh-CN.md

@@ -89,14 +89,25 @@ YOLOv8 也可以在 Python 环境中直接使用,并接受与上述 CLI 示例
 from ultralytics import YOLO

 # 加载模型
-model = YOLO("yolov8n.yaml")  # 从头开始构建新模型
-model = YOLO("yolov8n.pt")  # 加载预训练模型(建议用于训练)
-
-# 使用模型
-model.train(data="coco8.yaml", epochs=3)  # 训练模型
-metrics = model.val()  # 在验证集上评估模型性能
-results = model("https://ultralytics.com/images/bus.jpg")  # 对图像进行预测
-success = model.export(format="onnx")  # 将模型导出为 ONNX 格式
+model = YOLO("yolov8n.pt")
+
+# 训练模型
+train_results = model.train(
+    data="coco8.yaml",  # 数据配置文件的路径
+    epochs=100,  # 训练的轮数
+    imgsz=640,  # 训练图像大小
+    device="cpu",  # 运行的设备,例如 device=0 或 device=0,1,2,3 或 device=cpu
+)
+
+# 在验证集上评估模型性能
+metrics = model.val()
+
+# 对图像进行目标检测
+results = model("path/to/image.jpg")
+results[0].show()
+
+# 将模型导出为 ONNX 格式
+path = model.export(format="onnx")  # 返回导出的模型路径
 ```

 查看 YOLOv8 [Python 文档](https://docs.ultralytics.com/usage/python/)以获取更多示例。
@@ -141,23 +152,6 @@ Ultralytics 提供了 YOLOv8 的交互式笔记本,涵盖训练、验证、跟
 </details>

-<details><summary>检测(Open Image V7)</summary>
-
-查看[检测文档](https://docs.ultralytics.com/tasks/detect/)以获取这些在[Open Image V7](https://docs.ultralytics.com/datasets/detect/open-images-v7/)上训练的模型的使用示例,其中包括600个预训练类别。
-
-| 模型                                                                                       | 尺寸<br><sup>(像素) | mAP<sup>验证<br>50-95 | 速度<br><sup>CPU ONNX<br>(毫秒) | 速度<br><sup>A100 TensorRT<br>(毫秒) | 参数<br><sup>(M) | 浮点运算<br><sup>(B) |
-| ----------------------------------------------------------------------------------------- | ------------------- | --------------------- | ------------------------------- | ------------------------------------ | ---------------- | -------------------- |
-| [YOLOv8n](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n-oiv7.pt) | 640                 | 18.4                  | 142.4                           | 1.21                                 | 3.5              | 10.5                 |
-| [YOLOv8s](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8s-oiv7.pt) | 640                 | 27.7                  | 183.1                           | 1.40                                 | 11.4             | 29.7                 |
-| [YOLOv8m](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8m-oiv7.pt) | 640                 | 33.6                  | 408.5                           | 2.26                                 | 26.2             | 80.6                 |
-| [YOLOv8l](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8l-oiv7.pt) | 640                 | 34.9                  | 596.9                           | 2.43                                 | 44.1             | 167.4                |
-| [YOLOv8x](https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8x-oiv7.pt) | 640                 | 36.3                  | 860.6                           | 3.56                                 | 68.7             | 260.6                |
-
-- **mAP<sup>验证</sup>** 值适用于在[Open Image V7](https://docs.ultralytics.com/datasets/detect/open-images-v7/)数据集上的单模型单尺度。 <br>通过 `yolo val detect data=open-images-v7.yaml device=0` 以复现。
-- **速度** 在使用[Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/)实例对Open Image V7验证图像进行平均测算。 <br>通过 `yolo val detect data=open-images-v7.yaml batch=1 device=0|cpu` 以复现。
-
-</details>
-
 <details><summary>分割 (COCO)</summary>

 查看[分割文档](https://docs.ultralytics.com/tasks/segment/)以获取这些在[COCO-Seg](https://docs.ultralytics.com/datasets/segment/coco/)上训练的模型的使用示例,其中包括80个预训练类别。

docs/en/modes/predict.md

@@ -61,7 +61,7 @@ Ultralytics YOLO models return either a Python list of `Results` objects, or a m
 model = YOLO("yolov8n.pt")  # pretrained YOLOv8n model

 # Run batched inference on a list of images
-results = model(["im1.jpg", "im2.jpg"])  # return a list of Results objects
+results = model(["image1.jpg", "image2.jpg"])  # return a list of Results objects

 # Process results list
 for result in results:
@@ -83,7 +83,7 @@ Ultralytics YOLO models return either a Python list of `Results` objects, or a m
 model = YOLO("yolov8n.pt")  # pretrained YOLOv8n model

 # Run batched inference on a list of images
-results = model(["im1.jpg", "im2.jpg"], stream=True)  # return a generator of Results objects
+results = model(["image1.jpg", "image2.jpg"], stream=True)  # return a generator of Results objects

 # Process results generator
 for result in results:
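Both hunks cut off at the top of the results loop. For reference, a typical loop body over the `results` from either snippet, using the documented `Results` accessors:

```python
for result in results:
    boxes = result.boxes  # Boxes object for bounding box outputs
    masks = result.masks  # Masks object for segmentation masks outputs
    keypoints = result.keypoints  # Keypoints object for pose outputs
    probs = result.probs  # Probs object for classification outputs
    result.show()  # display to screen
    result.save(filename="result.jpg")  # save annotated image to disk
```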
@@ -109,8 +109,8 @@ YOLOv8 can process different types of input sources for inference, as shown in t
 | image      | `'image.jpg'`                              | `str` or `Path` | Single image file.                                           |
 | URL        | `'https://ultralytics.com/images/bus.jpg'` | `str`           | URL to an image.                                             |
 | screenshot | `'screen'`                                 | `str`           | Capture a screenshot.                                        |
-| PIL        | `Image.open('im.jpg')`                     | `PIL.Image`     | HWC format with RGB channels.                                |
-| OpenCV     | `cv2.imread('im.jpg')`                     | `np.ndarray`    | HWC format with BGR channels `uint8 (0-255)`.                |
+| PIL        | `Image.open('image.jpg')`                  | `PIL.Image`     | HWC format with RGB channels.                                |
+| OpenCV     | `cv2.imread('image.jpg')`                  | `np.ndarray`    | HWC format with BGR channels `uint8 (0-255)`.                |
 | numpy      | `np.zeros((640,1280,3))`                   | `np.ndarray`    | HWC format with BGR channels `uint8 (0-255)`.                |
 | torch      | `torch.zeros(16,3,320,640)`                | `torch.Tensor`  | BCHW format with RGB channels `float32 (0.0-1.0)`.           |
 | CSV        | `'sources.csv'`                            | `str` or `Path` | CSV file containing paths to images, videos, or directories. |
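The in-memory rows above differ only in channel order and layout. A quick sketch exercising three of them with one model, assuming a local `image.jpg` and `yolov8n.pt`:

```python
import cv2
import numpy as np
from PIL import Image

from ultralytics import YOLO

model = YOLO("yolov8n.pt")

pil_results = model(Image.open("image.jpg"))  # PIL source: HWC, RGB

bgr = cv2.imread("image.jpg")  # np.ndarray source: HWC, BGR, uint8
cv_results = model(bgr)

blank = np.zeros((640, 1280, 3), dtype=np.uint8)  # synthetic numpy source matching the table spec
blank_results = model(blank)
```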
@@ -710,16 +710,16 @@ When using YOLO models in a multi-threaded application, it's important to instan
 from ultralytics import YOLO


-def thread_safe_predict(image_path):
+def thread_safe_predict(model, image_path):
     """Performs thread-safe prediction on an image using a locally instantiated YOLO model."""
-    local_model = YOLO("yolov8n.pt")
-    results = local_model.predict(image_path)
+    model = YOLO(model)
+    results = model.predict(image_path)
     # Process results


 # Starting threads that each have their own model instance
-Thread(target=thread_safe_predict, args=("image1.jpg",)).start()
-Thread(target=thread_safe_predict, args=("image2.jpg",)).start()
+Thread(target=thread_safe_predict, args=("yolov8n.pt", "image1.jpg")).start()
+Thread(target=thread_safe_predict, args=("yolov8n.pt", "image2.jpg")).start()
 ```

 For an in-depth look at thread-safe inference with YOLO models and step-by-step instructions, please refer to our [YOLO Thread-Safe Inference Guide](../guides/yolo-thread-safe-inference.md). This guide will provide you with all the necessary information to avoid common pitfalls and ensure that your multi-threaded inference runs smoothly.
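The rewritten example passes the weight path instead of a shared model, so each thread constructs its own instance. The same idea extends to a pool; a minimal sketch with `concurrent.futures`, shown only as an illustration of per-task instantiation and not taken from the guide:

```python
from concurrent.futures import ThreadPoolExecutor

from ultralytics import YOLO


def thread_safe_predict(model_name, image_path):
    """Build a fresh YOLO instance per task so no state is shared across threads."""
    model = YOLO(model_name)
    return model.predict(image_path)


with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(thread_safe_predict, "yolov8n.pt", f"image{i}.jpg") for i in range(4)]
    results = [f.result() for f in futures]
```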

docs/en/modes/track.md

@@ -290,63 +290,35 @@ Finally, after all threads have completed their task, the windows displaying the
 from ultralytics import YOLO

-
-def run_tracker_in_thread(filename, model, file_index):
-    """
-    Runs a video file or webcam stream concurrently with the YOLOv8 model using threading.
-
-    This function captures video frames from a given file or camera source and utilizes the YOLOv8 model for object
-    tracking. The function runs in its own thread for concurrent processing.
-
-    Args:
-        filename (str): The path to the video file or the identifier for the webcam/external camera source.
-        model (obj): The YOLOv8 model object.
-        file_index (int): An index to uniquely identify the file being processed, used for display purposes.
-
-    Note:
-        Press 'q' to quit the video display window.
-    """
-    video = cv2.VideoCapture(filename)  # Read the video file
-
-    while True:
-        ret, frame = video.read()  # Read the video frames
-
-        # Exit the loop if no more frames in either video
-        if not ret:
-            break
-
-        # Track objects in frames if available
-        results = model.track(frame, persist=True)
-        res_plotted = results[0].plot()
-        cv2.imshow(f"Tracking_Stream_{file_index}", res_plotted)
-
-        key = cv2.waitKey(1)
-        if key == ord("q"):
-            break
-
-    # Release video sources
-    video.release()
-
-
-# Load the models
-model1 = YOLO("yolov8n.pt")
-model2 = YOLO("yolov8n-seg.pt")
-
-# Define the video files for the trackers
-video_file1 = "path/to/video1.mp4"  # Path to video file, 0 for webcam
-video_file2 = 0  # Path to video file, 0 for webcam, 1 for external camera
-
-# Create the tracker threads
-tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1, 1), daemon=True)
-tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2, 2), daemon=True)
-
-# Start the tracker threads
-tracker_thread1.start()
-tracker_thread2.start()
-
-# Wait for the tracker threads to finish
-tracker_thread1.join()
-tracker_thread2.join()
+# Define model names and video sources
+MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
+SOURCES = ["path/to/video.mp4", "0"]  # local video, 0 for webcam
+
+
+def run_tracker_in_thread(model_name, filename):
+    """
+    Run YOLO tracker in its own thread for concurrent processing.
+
+    Args:
+        model_name (str): The YOLOv8 model name or path, e.g. "yolov8n.pt".
+        filename (str): The path to the video file or the identifier for the webcam/external camera source.
+    """
+    model = YOLO(model_name)
+    results = model.track(filename, save=True, stream=True)
+    for r in results:
+        pass
+
+
+# Create and start tracker threads using a for loop
+tracker_threads = []
+for video_file, model_name in zip(SOURCES, MODEL_NAMES):
+    thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
+    tracker_threads.append(thread)
+    thread.start()
+
+# Wait for all tracker threads to finish
+for thread in tracker_threads:
+    thread.join()

 # Clean up and close windows
 cv2.destroyAllWindows()
@@ -408,35 +380,37 @@ To run object tracking on multiple video streams simultaneously, you can use Pyt
 from ultralytics import YOLO

-
-def run_tracker_in_thread(filename, model, file_index):
-    video = cv2.VideoCapture(filename)
-    while True:
-        ret, frame = video.read()
-        if not ret:
-            break
-        results = model.track(frame, persist=True)
-        res_plotted = results[0].plot()
-        cv2.imshow(f"Tracking_Stream_{file_index}", res_plotted)
-        if cv2.waitKey(1) & 0xFF == ord("q"):
-            break
-    video.release()
-
-
-model1 = YOLO("yolov8n.pt")
-model2 = YOLO("yolov8n-seg.pt")
-video_file1 = "path/to/video1.mp4"
-video_file2 = 0  # Path to a second video file, or 0 for a webcam
-
-tracker_thread1 = threading.Thread(target=run_tracker_in_thread, args=(video_file1, model1, 1), daemon=True)
-tracker_thread2 = threading.Thread(target=run_tracker_in_thread, args=(video_file2, model2, 2), daemon=True)
-
-tracker_thread1.start()
-tracker_thread2.start()
-
-tracker_thread1.join()
-tracker_thread2.join()
+# Define model names and video sources
+MODEL_NAMES = ["yolov8n.pt", "yolov8n-seg.pt"]
+SOURCES = ["path/to/video.mp4", "0"]  # local video, 0 for webcam
+
+
+def run_tracker_in_thread(model_name, filename):
+    """
+    Run YOLO tracker in its own thread for concurrent processing.
+
+    Args:
+        model_name (str): The YOLOv8 model name or path, e.g. "yolov8n.pt".
+        filename (str): The path to the video file or the identifier for the webcam/external camera source.
+    """
+    model = YOLO(model_name)
+    results = model.track(filename, save=True, stream=True)
+    for r in results:
+        pass
+
+
+# Create and start tracker threads using a for loop
+tracker_threads = []
+for video_file, model_name in zip(SOURCES, MODEL_NAMES):
+    thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
+    tracker_threads.append(thread)
+    thread.start()
+
+# Wait for all tracker threads to finish
+for thread in tracker_threads:
+    thread.join()

 # Clean up and close windows
 cv2.destroyAllWindows()
 ```
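Note that `model.track(..., stream=True)` returns a lazy generator, so the `for r in results: pass` loop is what actually drives tracking frame by frame. When the per-frame output is needed rather than discarded, the loop body can use each `Results` object directly; a small variant sketch of the same worker:

```python
from ultralytics import YOLO


def run_tracker_in_thread(model_name, filename):
    """Variant worker that inspects each per-frame Results object instead of discarding it."""
    model = YOLO(model_name)
    for r in model.track(filename, stream=True):  # lazy generator: one Results per frame
        boxes = r.boxes  # tracked boxes; boxes.id holds track IDs once tracks are established
        print(f"{filename}: {len(boxes)} objects")
```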

docs/en/yolov5/quickstart_tutorial.md

@@ -44,8 +44,8 @@ Harness `detect.py` for versatile inference on various sources. It automatically
 ```bash
 python detect.py --weights yolov5s.pt --source 0                               # webcam
-                                               img.jpg                         # image
-                                               vid.mp4                         # video
+                                               image.jpg                       # image
+                                               video.mp4                       # video
                                                screen                          # screenshot
                                                path/                           # directory
                                                list.txt                        # list of images

docs/mkdocs_github_authors.yaml

@@ -45,7 +45,7 @@
   username: zhixuwei
 49699333+dependabot[bot]@users.noreply.github.com:
   avatar: https://avatars.githubusercontent.com/u/27347476?v=4
-  username: dependabot[bot]
+  username: dependabot
 53246858+hasanghaffari93@users.noreply.github.com:
   avatar: https://avatars.githubusercontent.com/u/53246858?v=4
   username: hasanghaffari93

tests/__init__.py

@@ -1,13 +1,13 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks, is_dir_writeable
+from ultralytics.utils import ASSETS, ROOT, WEIGHTS_DIR, checks

 # Constants used in tests
 MODEL = WEIGHTS_DIR / "path with spaces" / "yolov8n.pt"  # test spaces in path
 CFG = "yolov8n.yaml"
 SOURCE = ASSETS / "bus.jpg"
+SOURCES_LIST = [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]
 TMP = (ROOT / "../tests/tmp").resolve()  # temp directory for test files
-IS_TMP_WRITEABLE = is_dir_writeable(TMP)
 CUDA_IS_AVAILABLE = checks.cuda_is_available()
 CUDA_DEVICE_COUNT = checks.cuda_device_count()
@@ -15,6 +15,7 @@ __all__ = (
     "MODEL",
     "CFG",
     "SOURCE",
+    "SOURCES_LIST",
     "TMP",
     "IS_TMP_WRITEABLE",
     "CUDA_IS_AVAILABLE",

tests/test_python.py

@@ -1,6 +1,7 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

 import contextlib
+import csv
 import urllib
 from copy import copy
 from pathlib import Path
@@ -12,7 +13,7 @@ import torch
 import yaml
 from PIL import Image

-from tests import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP
+from tests import CFG, MODEL, SOURCE, SOURCES_LIST, TMP
 from ultralytics import RTDETR, YOLO
 from ultralytics.cfg import MODELS, TASK2DATA, TASKS
 from ultralytics.data.build import load_inference_source
@@ -26,11 +27,14 @@ from ultralytics.utils import (
     WEIGHTS_DIR,
     WINDOWS,
     checks,
+    is_dir_writeable,
     is_github_action_running,
 )
 from ultralytics.utils.downloads import download
 from ultralytics.utils.torch_utils import TORCH_1_9

+IS_TMP_WRITEABLE = is_dir_writeable(TMP)  # WARNING: must be run once tests start as TMP does not exist on tests/__init__


 def test_model_forward():
     """Test the forward pass of the YOLO model."""
@@ -70,11 +74,37 @@ def test_model_profile():
 @pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
 def test_predict_txt():
     """Tests YOLO predictions with file, directory, and pattern sources listed in a text file."""
-    txt_file = TMP / "sources.txt"
-    with open(txt_file, "w") as f:
-        for x in [ASSETS / "bus.jpg", ASSETS, ASSETS / "*", ASSETS / "**/*.jpg"]:
-            f.write(f"{x}\n")
-    _ = YOLO(MODEL)(source=txt_file, imgsz=32)
+    file = TMP / "sources_multi_row.txt"
+    with open(file, "w") as f:
+        for src in SOURCES_LIST:
+            f.write(f"{src}\n")
+    results = YOLO(MODEL)(source=file, imgsz=32)
+    assert len(results) == 7  # 1 + 2 + 2 + 2 = 7 images
+
+
+@pytest.mark.skipif(True, reason="disabled for testing")
+@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
+def test_predict_csv_multi_row():
+    """Tests YOLO predictions with sources listed in multiple rows of a CSV file."""
+    file = TMP / "sources_multi_row.csv"
+    with open(file, "w", newline="") as f:
+        writer = csv.writer(f)
+        writer.writerow(["source"])
+        writer.writerows([[src] for src in SOURCES_LIST])
+    results = YOLO(MODEL)(source=file, imgsz=32)
+    assert len(results) == 7  # 1 + 2 + 2 + 2 = 7 images
+
+
+@pytest.mark.skipif(True, reason="disabled for testing")
+@pytest.mark.skipif(not IS_TMP_WRITEABLE, reason="directory is not writeable")
+def test_predict_csv_single_row():
+    """Tests YOLO predictions with sources listed in a single row of a CSV file."""
+    file = TMP / "sources_single_row.csv"
+    with open(file, "w", newline="") as f:
+        writer = csv.writer(f)
+        writer.writerow(SOURCES_LIST)
+    results = YOLO(MODEL)(source=file, imgsz=32)
+    assert len(results) == 7  # 1 + 2 + 2 + 2 = 7 images


 @pytest.mark.parametrize("model_name", MODELS)

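The new CSV tests lean on the predictor's source loader, which the file already imports as `load_inference_source`. A hedged sketch of exercising it directly, assuming it accepts the same `source` values as `YOLO.predict` and yields `(paths, images, info)` batches (the batch structure is an assumption here):

```python
from ultralytics.data.build import load_inference_source

# Build a dataset from a CSV of sources, one path per row as in test_predict_csv_multi_row
dataset = load_inference_source(source="sources_multi_row.csv")
for paths, imgs, info in dataset:  # assumed batch structure: paths, images, metadata string
    print(paths)
```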
ultralytics/__init__.py

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.2.94"
+__version__ = "8.2.95"

 import os

ultralytics/engine/trainer.py

@@ -668,13 +668,14 @@ class BaseTrainer:
     def final_eval(self):
         """Performs final evaluation and validation for object detection YOLO model."""
+        ckpt = {}
         for f in self.last, self.best:
             if f.exists():
-                strip_optimizer(f)  # strip optimizers
-                if f is self.best:
-                    if self.last.is_file():  # update best.pt train_metrics from last.pt
-                        k = "train_results"
-                        torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
+                if f is self.last:
+                    ckpt = strip_optimizer(f)
+                elif f is self.best:
+                    k = "train_results"  # update best.pt train_metrics from last.pt
+                    strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
                     LOGGER.info(f"\nValidating {f}...")
                     self.validator.args.plots = self.args.plots
                     self.metrics = self.validator(model=f)

ultralytics/nn/tasks.py

@@ -759,6 +759,10 @@ class SafeClass:
         """Initialize SafeClass instance, ignoring all arguments."""
         pass

+    def __call__(self, *args, **kwargs):
+        """Run SafeClass instance, ignoring all arguments."""
+        pass
+

 class SafeUnpickler(pickle.Unpickler):
     """Custom Unpickler that replaces unknown classes with SafeClass."""

ultralytics/utils/torch_utils.py

@@ -533,16 +533,17 @@ class ModelEMA:
         copy_attr(self.ema, model, include, exclude)


-def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
+def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: dict = None) -> dict:
     """
     Strip optimizer from 'f' to finalize training, optionally save as 's'.

     Args:
         f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
         s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.
+        updates (dict): a dictionary of updates to overlay onto the checkpoint before saving.

     Returns:
-        None
+        (dict): The combined checkpoint dictionary.

     Example:
         ```python
@@ -562,9 +563,9 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
         assert "model" in x, "'model' missing from checkpoint"
     except Exception as e:
         LOGGER.warning(f"WARNING ⚠ Skipping {f}, not a valid Ultralytics model: {e}")
-        return
+        return {}

-    updates = {
+    metadata = {
         "date": datetime.now().isoformat(),
         "version": __version__,
         "license": "AGPL-3.0 License (https://ultralytics.com/license)",
@@ -591,9 +592,11 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
     # x['model'].args = x['train_args']

     # Save
-    torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)
+    combined = {**metadata, **x, **(updates or {})}
+    torch.save(combined, s or f, use_dill=False)  # combine dicts (prefer to the right)
     mb = os.path.getsize(s or f) / 1e6  # file size
     LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
+    return combined


 def convert_optimizer_state_dict_to_fp16(state_dict):
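Taken together with the `final_eval` change above, the new return value and `updates` overlay let the trainer copy `train_results` from `last.pt` into `best.pt` in one pass. A small usage sketch, assuming both checkpoints exist on disk:

```python
from ultralytics.utils.torch_utils import strip_optimizer

# Strip last.pt and capture the combined checkpoint dict it now returns
ckpt = strip_optimizer("last.pt")

# Strip best.pt while overlaying train_results recovered from last.pt
k = "train_results"
strip_optimizer("best.pt", updates={k: ckpt[k]} if k in ckpt else None)
```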
