Merge branch 'main' into fix-linker-issue-cpp-inference

pull/17521/head
commit 862e808a85, authored by Glenn Jocher, committed by GitHub
GPG Key ID: B5690EEEBB952194
Files changed (15):

1. docs/en/guides/nvidia-jetson.md (2 changes)
2. docs/en/guides/region-counting.md (85 changes)
3. docs/en/hub/models.md (6 changes)
4. docs/en/integrations/sony-imx500.md (6 changes)
5. docs/en/reference/solutions/region_counter.md (16 changes)
6. docs/mkdocs_github_authors.yaml (3 changes)
7. mkdocs.yml (1 change)
8. ultralytics/__init__.py (2 changes)
9. ultralytics/cfg/__init__.py (3 changes)
10. ultralytics/engine/results.py (2 changes)
11. ultralytics/solutions/__init__.py (2 changes)
12. ultralytics/solutions/analytics.py (2 changes)
13. ultralytics/solutions/region_counter.py (112 changes)
14. ultralytics/solutions/solutions.py (2 changes)
15. ultralytics/utils/tal.py (40 changes)

@@ -54,7 +54,7 @@ The first step after getting your hands on an NVIDIA Jetson device is to flash N
 1. If you own an official NVIDIA Development Kit such as the Jetson Orin Nano Developer Kit, you can [download an image and prepare an SD card with JetPack for booting the device](https://developer.nvidia.com/embedded/learn/get-started-jetson-orin-nano-devkit).
 2. If you own any other NVIDIA Development Kit, you can [flash JetPack to the device using SDK Manager](https://docs.nvidia.com/sdk-manager/install-with-sdkm-jetson/index.html).
-3. If you own a Seeed Studio reComputer J4012 device, you can [flash JetPack to the included SSD](https://wiki.seeedstudio.com/reComputer_J4012_Flash_Jetpack/) and if you own a Seeed Studio reComputer J1020 v2 device, you can [flash JetPack to the eMMC/ SSD](https://wiki.seeedstudio.com/reComputer_J2021_J202_Flash_Jetpack/).
+3. If you own a Seeed Studio reComputer J4012 device, you can [flash JetPack to the included SSD](https://wiki.seeedstudio.com/recomputer_j4012_flash_jetpack/) and if you own a Seeed Studio reComputer J1020 v2 device, you can [flash JetPack to the eMMC/ SSD](https://wiki.seeedstudio.com/reComputer_J2021_J202_Flash_Jetpack/).
 4. If you own any other third party device powered by the NVIDIA Jetson module, it is recommended to follow [command-line flashing](https://docs.nvidia.com/jetson/archives/r35.5.0/DeveloperGuide/IN/QuickStart.html).

 !!! note

@@ -34,56 +34,65 @@ keywords: object counting, regions, YOLOv8, computer vision, Ultralytics, effici

 | ![People Counting in Different Region using Ultralytics YOLOv8](https://github.com/ultralytics/docs/releases/download/0/people-counting-different-region-ultralytics-yolov8.avif) | ![Crowd Counting in Different Region using Ultralytics YOLOv8](https://github.com/ultralytics/docs/releases/download/0/crowd-counting-different-region-ultralytics-yolov8.avif) |
 | People Counting in Different Region using Ultralytics YOLOv8 | Crowd Counting in Different Region using Ultralytics YOLOv8 |

-## Steps to Run
-
-### Step 1: Install Required Libraries
-
-Begin by cloning the Ultralytics repository, installing dependencies, and navigating to the local directory using the provided commands in Step 2.
-
-```bash
-# Clone Ultralytics repo
-git clone https://github.com/ultralytics/ultralytics
-
-# Navigate to the local directory
-cd ultralytics/examples/YOLOv8-Region-Counter
-```
-
-### Step 2: Run Region Counting Using Ultralytics YOLOv8
-
-Execute the following basic commands for inference.
-
-???+ tip "Region is Movable"
-
-    During video playback, you can interactively move the region within the video by clicking and dragging using the left mouse button.
-
-```bash
-# Save results
-python yolov8_region_counter.py --source "path/to/video.mp4" --save-img
-
-# Run model on CPU
-python yolov8_region_counter.py --source "path/to/video.mp4" --device cpu
-
-# Change model file
-python yolov8_region_counter.py --source "path/to/video.mp4" --weights "path/to/model.pt"
-
-# Detect specific classes (e.g., first and third classes)
-python yolov8_region_counter.py --source "path/to/video.mp4" --classes 0 2
-
-# View results without saving
-python yolov8_region_counter.py --source "path/to/video.mp4" --view-img
-```
-
-### Optional Arguments
-
-| Name                 | Type   | Default      | Description                                                                  |
-| -------------------- | ------ | ------------ | ---------------------------------------------------------------------------- |
-| `--source`           | `str`  | `None`       | Path to video file, for webcam 0                                             |
-| `--line_thickness`   | `int`  | `2`          | [Bounding Box](https://www.ultralytics.com/glossary/bounding-box) thickness  |
-| `--save-img`         | `bool` | `False`      | Save the predicted video/image                                               |
-| `--weights`          | `str`  | `yolov8n.pt` | Weights file path                                                            |
-| `--classes`          | `list` | `None`       | Detect specific classes i.e. --classes 0 2                                   |
-| `--region-thickness` | `int`  | `2`          | Region Box thickness                                                         |
-| `--track-thickness`  | `int`  | `2`          | Tracking line thickness                                                      |
+!!! example "Region Counting Example"
+
+    === "Python"
+
+        ```python
+        import cv2
+
+        from ultralytics import solutions
+
+        cap = cv2.VideoCapture("Path/to/video/file.mp4")
+        assert cap.isOpened(), "Error reading video file"
+        w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS))
+
+        # Define region points
+        # region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)]  # Pass region as list
+
+        # Pass region as dictionary
+        region_points = {
+            "region-01": [(50, 50), (250, 50), (250, 250), (50, 250)],
+            "region-02": [(640, 640), (780, 640), (780, 720), (640, 720)],
+        }
+
+        # Video writer
+        video_writer = cv2.VideoWriter("region_counting.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
+
+        # Init Object Counter
+        region = solutions.RegionCounter(
+            show=True,
+            region=region_points,
+            model="yolo11n.pt",
+        )
+
+        # Process video
+        while cap.isOpened():
+            success, im0 = cap.read()
+            if not success:
+                print("Video frame is empty or video processing has been successfully completed.")
+                break
+            im0 = region.count(im0)
+            video_writer.write(im0)
+
+        cap.release()
+        video_writer.release()
+        cv2.destroyAllWindows()
+        ```
+
+!!! tip "Ultralytics Example Code"
+
+    The Ultralytics region counting module is available in our [examples section](https://github.com/ultralytics/ultralytics/blob/main/examples/YOLOv8-Region-Counter/yolov8_region_counter.py). You can explore this example for code customization and modify it to suit your specific use case.
+
+### Argument `RegionCounter`
+
+Here's a table with the `RegionCounter` arguments:
+
+| Name         | Type   | Default                    | Description                                           |
+| ------------ | ------ | -------------------------- | ----------------------------------------------------- |
+| `model`      | `str`  | `None`                     | Path to Ultralytics YOLO Model File                   |
+| `region`     | `list` | `[(20, 400), (1260, 400)]` | List of points defining the counting region.          |
+| `line_width` | `int`  | `2`                        | Line thickness for bounding boxes.                    |
+| `show`       | `bool` | `False`                    | Flag to control whether to display the video stream.  |
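For quick orientation, here is a minimal single-frame sketch of the new `RegionCounter` API added above; the blank frame and region coordinates are illustrative assumptions, not part of the docs change:

```python
import numpy as np

from ultralytics import solutions

# A blank 720p frame stands in for a real video frame (illustrative only)
frame = np.zeros((720, 1280, 3), dtype=np.uint8)

# A single region may also be passed as a plain list of points (see the table above)
counter = solutions.RegionCounter(
    region=[(100, 100), (600, 100), (600, 600), (100, 600)],
    model="yolo11n.pt",  # weights are downloaded on first use
    show=False,
)

annotated = counter.count(frame)  # returns the frame annotated with per-region counts
print(annotated.shape)  # (720, 1280, 3)
```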
## FAQ
@@ -107,7 +116,7 @@ Follow these steps to run object counting in Ultralytics YOLOv8:
     python yolov8_region_counter.py --source "path/to/video.mp4" --save-img
     ```

-For more options, visit the [Run Region Counting](#steps-to-run) section.
+For more options, visit the [Run Region Counting](https://github.com/ultralytics/ultralytics/blob/main/examples/YOLOv8-Region-Counter/readme.md) section.

 ### Why should I use Ultralytics YOLOv8 for object counting in regions?
@@ -121,7 +130,7 @@ Explore deeper benefits in the [Advantages](#advantages-of-object-counting-in-re

 ### Can the defined regions be adjusted during video playback?

-Yes, with Ultralytics YOLOv8, regions can be interactively moved during video playback. Simply click and drag with the left mouse button to reposition the region. This feature enhances flexibility for dynamic environments. Learn more in the tip section for [movable regions](#step-2-run-region-counting-using-ultralytics-yolov8).
+Yes, with Ultralytics YOLOv8, regions can be interactively moved during video playback. Simply click and drag with the left mouse button to reposition the region. This feature enhances flexibility for dynamic environments. Learn more in the tip section for [movable regions](https://github.com/ultralytics/ultralytics/blob/33cdaa5782efb2bc2b5ede945771ba647882830d/examples/YOLOv8-Region-Counter/yolov8_region_counter.py#L39).

 ### What are some real-world applications of object counting in regions?

@@ -1,7 +1,7 @@
 ---
 comments: true
-description: Explore Ultralytics HUB for easy training, analysis, preview, deployment and sharing of custom vision AI models using YOLOv8. Start training today!
-keywords: Ultralytics HUB, YOLOv8, custom AI models, model training, model deployment, model analysis, vision AI
+description: Explore Ultralytics HUB for easy training, analysis, preview, deployment and sharing of custom vision AI models using YOLO11. Start training today!
+keywords: Ultralytics HUB, YOLO11, custom AI models, model training, model deployment, model analysis, vision AI
 ---

 # Ultralytics HUB Models
@@ -66,7 +66,7 @@ In this step, you have to choose the project in which you want to create your mo

 !!! info

-    You can read more about the available [YOLOv8](https://docs.ultralytics.com/models/yolov8/) (and [YOLOv5](https://docs.ultralytics.com/models/yolov5/)) architectures in our documentation.
+    You can read more about the available [YOLO models](https://docs.ultralytics.com/models) and architectures in our documentation.

 By default, your model will use a pre-trained model (trained on the [COCO](https://docs.ultralytics.com/datasets/detect/coco/) dataset) to reduce training time. You can change this behavior and tweak your model's configuration by opening the **Advanced Model Configuration** accordion.

@@ -29,7 +29,7 @@ The IMX500 works with quantized models. Quantization makes models smaller and fa

 **IMX500 Key Features:**

-- **Metadata Output:** Instead of transmitting full images, the IMX500 outputs only metadata, minimizing data size, reducing bandwidth, and lowering costs.
+- **Metadata Output:** Instead of transmitting images only, the IMX500 can output both image and metadata (inference result), and can output metadata only, minimizing data size, reducing bandwidth, and lowering costs.
 - **Addresses Privacy Concerns:** By processing data on the device, the IMX500 addresses privacy concerns, ideal for human-centric applications like person counting and occupancy tracking.
 - **Real-time Processing:** Fast, on-sensor processing supports real-time decisions, perfect for edge AI applications such as autonomous systems.
@@ -247,7 +247,7 @@ Export to IMX500 format has wide applicability across industries. Here are some

 ## Conclusion

-Exporting Ultralytics YOLOv8 models to Sony's IMX500 format allows you to deploy your models for efficient inference on IMX500-based cameras. By leveraging advanced quantization and pruning techniques, you can reduce model size and improve inference speed without significantly compromising accuracy.
+Exporting Ultralytics YOLOv8 models to Sony's IMX500 format allows you to deploy your models for efficient inference on IMX500-based cameras. By leveraging advanced quantization techniques, you can reduce model size and improve inference speed without significantly compromising accuracy.

 For more information and detailed guidelines, refer to Sony's [IMX500 website](https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera).
@@ -271,7 +271,7 @@ The export process will create a directory containing the necessary files for de

 The IMX500 format offers several important advantages for edge deployment:

 - On-chip AI processing reduces latency and power consumption
-- Outputs metadata instead of full images, minimizing bandwidth usage
+- Outputs both image and metadata (inference result) instead of images only
 - Enhanced privacy by processing data locally without cloud dependency
 - Real-time processing capabilities ideal for time-sensitive applications
 - Optimized quantization for efficient model deployment on resource-constrained devices
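For context, the export this guide documents is a one-liner from Python. A hedged sketch, assuming the `imx` export key and a local `yolov8n.pt` checkpoint:

```python
from ultralytics import YOLO

model = YOLO("yolov8n.pt")  # load a pretrained YOLOv8n checkpoint
model.export(format="imx")  # write IMX500 deployment artifacts (output directory name varies by model)
```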

@@ -0,0 +1,16 @@
+---
+description: Explore the Ultralytics Object Counter for real-time video streams. Learn about initializing parameters, tracking objects, and more.
+keywords: Ultralytics, Object Counter, Real-time Tracking, Video Stream, Python, Object Detection
+---
+
+# Reference for `ultralytics/solutions/region_counter.py`
+
+!!! note
+
+    This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/region_counter.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/region_counter.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/region_counter.py) 🛠. Thank you 🙏!
+
+<br>
+
+## ::: ultralytics.solutions.region_counter.RegionCounter
+
+<br><br>

@@ -10,6 +10,9 @@
 130829914+IvorZhu331@users.noreply.github.com:
   avatar: https://avatars.githubusercontent.com/u/130829914?v=4
   username: IvorZhu331
+131249114+ServiAmirPM@users.noreply.github.com:
+  avatar: https://avatars.githubusercontent.com/u/131249114?v=4
+  username: ServiAmirPM
 131261051+MatthewNoyce@users.noreply.github.com:
   avatar: https://avatars.githubusercontent.com/u/131261051?v=4
   username: MatthewNoyce

@@ -571,6 +571,7 @@ nav:
         - solutions: reference/solutions/solutions.md
         - speed_estimation: reference/solutions/speed_estimation.md
         - streamlit_inference: reference/solutions/streamlit_inference.md
+        - region_counter: reference/solutions/region_counter.md
     - trackers:
         - basetrack: reference/trackers/basetrack.md
         - bot_sort: reference/trackers/bot_sort.md

@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = "8.3.29"
+__version__ = "8.3.30"

 import os

@@ -671,6 +671,9 @@ def handle_yolo_solutions(args: List[str]) -> None:
         )
         s_n = "count"  # Default solution if none provided

+    if args and args[0] == "help":  # Add check for return if user call `yolo solutions help`
+        return

     cls, method = SOLUTION_MAP[s_n]  # solution class name, method name and default source

     from ultralytics import solutions  # import ultralytics solutions
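What the new guard changes, sketched against the internal entry point (hedged: calling the helper directly is for illustration only; the supported interface is the `yolo solutions` CLI):

```python
from ultralytics.cfg import handle_yolo_solutions

# With the early return above, "help" no longer falls through to the default
# "count" solution, which would otherwise try to open a video source.
handle_yolo_solutions(["help"])
```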

@@ -750,7 +750,7 @@ class Results(SimpleClass):
             save_one_box(
                 d.xyxy,
                 self.orig_img.copy(),
-                file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg",
+                file=Path(save_dir) / self.names[int(d.cls)] / Path(file_name).with_suffix(".jpg"),
                 BGR=True,
             )
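The motivation for this one-line change, shown standalone: the old f-string appends a second extension when `file_name` already carries one, while `with_suffix` replaces it. A small illustration with a hypothetical name:

```python
from pathlib import Path

file_name = "crops/im0.jpg"  # hypothetical name that already has an extension

print(f"{Path(file_name)}.jpg")  # old behavior: crops/im0.jpg.jpg (doubled suffix)
print(Path(file_name).with_suffix(".jpg"))  # new behavior: crops/im0.jpg
```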

@@ -7,6 +7,7 @@ from .heatmap import Heatmap
 from .object_counter import ObjectCounter
 from .parking_management import ParkingManagement, ParkingPtsSelection
 from .queue_management import QueueManager
+from .region_counter import RegionCounter
 from .speed_estimation import SpeedEstimator
 from .streamlit_inference import inference

@@ -21,4 +22,5 @@ __all__ = (
     "SpeedEstimator",
     "Analytics",
     "inference",
+    "RegionCounter",
 )

@@ -54,7 +54,7 @@ class Analytics(BaseSolution):
         self.y_label = "Total Counts"

         # Predefined data
-        self.bg_color = "#00F344"  # background color of frame
+        self.bg_color = "#F3F3F3"  # background color of frame
         self.fg_color = "#111E68"  # foreground color of frame
         self.title = "Ultralytics Solutions"  # window name
         self.max_points = 45  # maximum points to be drawn on window

@@ -0,0 +1,112 @@
+# Ultralytics YOLO 🚀, AGPL-3.0 license

+from ultralytics.solutions.solutions import BaseSolution
+from ultralytics.utils.plotting import Annotator, colors


+class RegionCounter(BaseSolution):
+    """
+    A class designed for real-time counting of objects within user-defined regions in a video stream.
+
+    This class inherits from `BaseSolution` and offers functionalities to define polygonal regions in a video
+    frame, track objects, and count those objects that pass through each defined region. This makes it useful
+    for applications that require counting in specified areas, such as monitoring zones or segmented sections.
+
+    Attributes:
+        region_template (dict): A template for creating new counting regions with default attributes including
+            the name, polygon coordinates, and display colors.
+        counting_regions (list): A list storing all defined regions, where each entry is based on `region_template`
+            and includes specific region settings like name, coordinates, and color.
+
+    Methods:
+        add_region: Adds a new counting region with specified attributes, such as the region's name, polygon points,
+            region color, and text color.
+        count: Processes video frames to count objects in each region, drawing regions and displaying counts
+            on the frame. Handles object detection, region definition, and containment checks.
+    """

+    def __init__(self, **kwargs):
+        """Initializes the RegionCounter class for real-time counting in different regions of the video streams."""
+        super().__init__(**kwargs)
+        self.region_template = {
+            "name": "Default Region",
+            "polygon": None,
+            "counts": 0,
+            "dragging": False,
+            "region_color": (255, 255, 255),
+            "text_color": (0, 0, 0),
+        }
+        self.counting_regions = []

+    def add_region(self, name, polygon_points, region_color, text_color):
+        """
+        Adds a new region to the counting list based on the provided template with specific attributes.
+
+        Args:
+            name (str): Name assigned to the new region.
+            polygon_points (list[tuple]): List of (x, y) coordinates defining the region's polygon.
+            region_color (tuple): BGR color for region visualization.
+            text_color (tuple): BGR color for the text within the region.
+        """
+        region = self.region_template.copy()
+        region.update(
+            {
+                "name": name,
+                "polygon": self.Polygon(polygon_points),
+                "region_color": region_color,
+                "text_color": text_color,
+            }
+        )
+        self.counting_regions.append(region)

+    def count(self, im0):
+        """
+        Processes the input frame to detect and count objects within each defined region.
+
+        Args:
+            im0 (numpy.ndarray): Input image frame where objects and regions are annotated.
+
+        Returns:
+            im0 (numpy.ndarray): Processed image frame with annotated counting information.
+        """
+        self.annotator = Annotator(im0, line_width=self.line_width)
+        self.extract_tracks(im0)

+        # Region initialization and conversion
+        if self.region is None:
+            self.initialize_region()
+            regions = {"Region#01": self.region}
+        else:
+            regions = self.region if isinstance(self.region, dict) else {"Region#01": self.region}

+        # Draw regions and process counts for each defined area
+        for idx, (region_name, reg_pts) in enumerate(regions.items(), start=1):
+            color = colors(idx, True)
+            self.annotator.draw_region(reg_pts=reg_pts, color=color, thickness=self.line_width * 2)
+            self.add_region(region_name, reg_pts, color, self.annotator.get_txt_color())

+        # Prepare regions for containment check
+        for region in self.counting_regions:
+            region["prepared_polygon"] = self.prep(region["polygon"])

+        # Process bounding boxes and count objects within each region
+        for box, cls in zip(self.boxes, self.clss):
+            self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True))
+            bbox_center = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2)

+            for region in self.counting_regions:
+                if region["prepared_polygon"].contains(self.Point(bbox_center)):
+                    region["counts"] += 1

+        # Display counts in each region
+        for region in self.counting_regions:
+            self.annotator.text_label(
+                region["polygon"].bounds,
+                label=str(region["counts"]),
+                color=region["region_color"],
+                txt_color=region["text_color"],
+            )
+            region["counts"] = 0  # Reset count for next frame
+        self.display_output(im0)

+        return im0
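The heart of the new `count` method is a point-in-polygon test on each box center. A standalone sketch of that check with shapely, using toy coordinates rather than anything from this file:

```python
from shapely.geometry import Point, Polygon
from shapely.prepared import prep

region = Polygon([(50, 50), (250, 50), (250, 250), (50, 250)])
prepared = prep(region)  # same preparation step the class performs each frame

x1, y1, x2, y2 = 120, 80, 180, 160  # a detection box in xyxy form
center = Point((x1 + x2) / 2, (y1 + y2) / 2)  # (150.0, 120.0)

print(prepared.contains(center))  # True: the box center lies inside the region
```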

@@ -50,10 +50,12 @@ class BaseSolution:
         """
         check_requirements("shapely>=2.0.0")
         from shapely.geometry import LineString, Point, Polygon
+        from shapely.prepared import prep

         self.LineString = LineString
         self.Polygon = Polygon
         self.Point = Point
+        self.prep = prep

         # Load config and update with args
         DEFAULT_SOL_DICT.update(kwargs)
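Why `prep` earns its import: a prepared polygon precomputes an internal index, so the many per-frame `contains()` calls made by solutions such as `RegionCounter` run faster. A hedged micro-benchmark (timings are machine-dependent):

```python
import time

from shapely.geometry import Point, Polygon
from shapely.prepared import prep

polygon = Polygon([(0, 0), (1000, 0), (1000, 1000), (0, 1000)])
points = [Point(i % 1000, (i * 7) % 1000) for i in range(100_000)]

t0 = time.perf_counter()
plain = sum(polygon.contains(p) for p in points)
t1 = time.perf_counter()

prepared = prep(polygon)
fast = sum(prepared.contains(p) for p in points)
t2 = time.perf_counter()

print(plain == fast, f"plain {t1 - t0:.3f}s vs prepared {t2 - t1:.3f}s")
```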

@@ -58,17 +58,45 @@ class TaskAlignedAssigner(nn.Module):
         """
         self.bs = pd_scores.shape[0]
         self.n_max_boxes = gt_bboxes.shape[1]
+        device = gt_bboxes.device

         if self.n_max_boxes == 0:
-            device = gt_bboxes.device
             return (
-                torch.full_like(pd_scores[..., 0], self.bg_idx).to(device),
-                torch.zeros_like(pd_bboxes).to(device),
-                torch.zeros_like(pd_scores).to(device),
-                torch.zeros_like(pd_scores[..., 0]).to(device),
-                torch.zeros_like(pd_scores[..., 0]).to(device),
+                torch.full_like(pd_scores[..., 0], self.bg_idx),
+                torch.zeros_like(pd_bboxes),
+                torch.zeros_like(pd_scores),
+                torch.zeros_like(pd_scores[..., 0]),
+                torch.zeros_like(pd_scores[..., 0]),
             )

+        try:
+            return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
+        except torch.OutOfMemoryError:
+            # Move tensors to CPU, compute, then move back to original device
+            cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
+            result = self._forward(*cpu_tensors)
+            return tuple(t.to(device) for t in result)

+    def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
+        """
+        Compute the task-aligned assignment. Reference code is available at
+        https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py.
+
+        Args:
+            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            anc_points (Tensor): shape(num_total_anchors, 2)
+            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
+            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+            mask_gt (Tensor): shape(bs, n_max_boxes, 1)
+
+        Returns:
+            target_labels (Tensor): shape(bs, num_total_anchors)
+            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            fg_mask (Tensor): shape(bs, num_total_anchors)
+            target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        """
         mask_pos, align_metric, overlaps = self.get_pos_mask(
             pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
         )
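The added fallback is a general pattern: attempt the computation on the tensors' native device, and on GPU memory exhaustion redo it on CPU and move the results back. A hedged generic sketch (`torch.OutOfMemoryError` is a recent top-level alias; older PyTorch spells it `torch.cuda.OutOfMemoryError`, used below):

```python
import torch


def with_cpu_fallback(fn, *tensors):
    """Run fn on the tensors' native device; retry on CPU if GPU memory runs out."""
    device = tensors[0].device
    try:
        return fn(*tensors)
    except torch.cuda.OutOfMemoryError:
        cpu_tensors = [t.cpu() for t in tensors]  # move inputs off the GPU
        result = fn(*cpu_tensors)  # recompute on CPU
        return tuple(t.to(device) for t in result)  # restore the original device
```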
