diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 07f8539a75..04463a9b97 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -102,21 +102,19 @@ jobs: python-version: ["3.11"] model: [yolo11n] steps: + - uses: astral-sh/setup-uv@v3 - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - cache: "pip" # caching pip dependencies - name: Install requirements shell: bash # for Windows compatibility run: | - # Warnings: uv causes numpy errors during benchmarking - python -m pip install --upgrade pip wheel - pip install -e ".[export]" "coverage[toml]" --extra-index-url https://download.pytorch.org/whl/cpu + uv pip install --system -e ".[export]" "coverage[toml]" --extra-index-url https://download.pytorch.org/whl/cpu --index-strategy unsafe-first-match - name: Check environment run: | yolo checks - pip list + uv pip list - name: Benchmark DetectionModel shell: bash run: coverage run -a --source=ultralytics -m ultralytics.cfg.__init__ benchmark model='path with spaces/${{ matrix.model }}.pt' imgsz=160 verbose=0.309 diff --git a/docs/en/guides/instance-segmentation-and-tracking.md b/docs/en/guides/instance-segmentation-and-tracking.md index a910a21d8f..12cd7477a6 100644 --- a/docs/en/guides/instance-segmentation-and-tracking.md +++ b/docs/en/guides/instance-segmentation-and-tracking.md @@ -82,15 +82,11 @@ There are two types of instance segmentation tracking available in the Ultralyti === "Instance Segmentation with Object Tracking" ```python - from collections import defaultdict - import cv2 from ultralytics import YOLO from ultralytics.utils.plotting import Annotator, colors - track_history = defaultdict(lambda: []) - model = YOLO("yolo11n-seg.pt") # segmentation model cap = cv2.VideoCapture("path/to/video/file.mp4") w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) @@ -205,15 +201,11 @@ To implement object tracking, use the `model.track` method and ensure that each === "Python" ```python - from collections import defaultdict - import cv2 from ultralytics import YOLO from ultralytics.utils.plotting import Annotator, colors - track_history = defaultdict(lambda: []) - model = YOLO("yolo11n-seg.pt") # segmentation model cap = cv2.VideoCapture("path/to/video/file.mp4") w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) diff --git a/docs/en/guides/raspberry-pi.md b/docs/en/guides/raspberry-pi.md index a834d8f074..ef41d0f8cf 100644 --- a/docs/en/guides/raspberry-pi.md +++ b/docs/en/guides/raspberry-pi.md @@ -142,9 +142,10 @@ YOLO11 benchmarks were run by the Ultralytics team on nine different model forma We have only included benchmarks for YOLO11n and YOLO11s models because other models sizes are too big to run on the Raspberry Pis and does not offer decent performance. -
- YOLO11 benchmarks on RPi 5 -
+
+ YOLO11 benchmarks on RPi 5 +
Benchmarked with Ultralytics v8.3.39
+
### Detailed Comparison Table @@ -156,29 +157,33 @@ The below table represents the benchmark results for two different models (YOLO1 | Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) | |---------------|--------|-------------------|-------------|------------------------| - | PyTorch | ✅ | 5.4 | 0.61 | 524.828 | - | TorchScript | ✅ | 10.5 | 0.6082 | 666.874 | - | ONNX | ✅ | 10.2 | 0.6082 | 181.818 | - | OpenVINO | ✅ | 10.4 | 0.6082 | 530.224 | - | TF SavedModel | ✅ | 25.8 | 0.6082 | 405.964 | - | TF GraphDef | ✅ | 10.3 | 0.6082 | 473.558 | - | TF Lite | ✅ | 10.3 | 0.6082 | 324.158 | - | PaddlePaddle | ✅ | 20.4 | 0.6082 | 644.312 | - | NCNN | ✅ | 10.2 | 0.6106 | 93.938 | + | PyTorch | ✅ | 5.4 | 0.6100 | 405.238 | + | TorchScript | ✅ | 10.5 | 0.6082 | 526.628 | + | ONNX | ✅ | 10.2 | 0.6082 | 168.082 | + | OpenVINO | ✅ | 10.4 | 0.6082 | 81.192 | + | TF SavedModel | ✅ | 25.8 | 0.6082 | 377.968 | + | TF GraphDef | ✅ | 10.3 | 0.6082 | 487.244 | + | TF Lite | ✅ | 10.3 | 0.6082 | 317.398 | + | PaddlePaddle | ✅ | 20.4 | 0.6082 | 561.892 | + | MNN | ✅ | 10.1 | 0.6106 | 112.554 | + | NCNN | ✅ | 10.2 | 0.6106 | 88.026 | === "YOLO11s" | Format | Status | Size on disk (MB) | mAP50-95(B) | Inference time (ms/im) | |---------------|--------|-------------------|-------------|------------------------| - | PyTorch | ✅ | 18.4 | 0.7526 | 1226.426 | - | TorchScript | ✅ | 36.5 | 0.7416 | 1507.95 | - | ONNX | ✅ | 36.3 | 0.7416 | 415.24 | - | OpenVINO | ✅ | 36.4 | 0.7416 | 1167.102 | - | TF SavedModel | ✅ | 91.1 | 0.7416 | 776.14 | - | TF GraphDef | ✅ | 36.4 | 0.7416 | 1014.396 | - | TF Lite | ✅ | 36.4 | 0.7416 | 845.934 | - | PaddlePaddle | ✅ | 72.5 | 0.7416 | 1567.824 | - | NCNN | ✅ | 36.2 | 0.7419 | 197.358 | + | PyTorch | ✅ | 18.4 | 0.7526 | 1011.60 | + | TorchScript | ✅ | 36.5 | 0.7416 | 1268.502 | + | ONNX | ✅ | 36.3 | 0.7416 | 324.17 | + | OpenVINO | ✅ | 36.4 | 0.7416 | 179.324 | + | TF SavedModel | ✅ | 91.1 | 0.7416 | 714.382 | + | TF GraphDef | ✅ | 36.4 | 0.7416 | 1019.83 | + | TF Lite | ✅ | 36.4 | 0.7416 | 849.86 | + | PaddlePaddle | ✅ | 72.5 | 0.7416 | 1276.34 | + | MNN | ✅ | 36.2 | 0.7409 | 273.032 | + | NCNN | ✅ | 36.2 | 0.7419 | 194.858 | + + Benchmarked with Ultralytics `v8.3.39` ## Reproduce Our Results diff --git a/docs/en/guides/trackzone.md b/docs/en/guides/trackzone.md new file mode 100644 index 0000000000..ec98221385 --- /dev/null +++ b/docs/en/guides/trackzone.md @@ -0,0 +1,160 @@ +--- +comments: true +description: Discover how TrackZone leverages Ultralytics YOLO11 to precisely track objects within specific zones, enabling real-time insights for crowd analysis, surveillance, and targeted monitoring. +keywords: TrackZone, object tracking, YOLO11, Ultralytics, real-time object detection, AI, deep learning, crowd analysis, surveillance, zone-based tracking, resource optimization +--- + +# TrackZone using Ultralytics YOLO11 + +## What is TrackZone? + +TrackZone specializes in monitoring objects within designated areas of a frame instead of the whole frame. Built on [Ultralytics YOLO11](https://github.com/ultralytics/ultralytics/), it integrates object detection and tracking specifically within zones for videos and live camera feeds. YOLO11's advanced algorithms and [deep learning](https://www.ultralytics.com/glossary/deep-learning-dl) technologies make it a perfect choice for real-time use cases, offering precise and efficient object tracking in applications like crowd monitoring and surveillance. + +## Advantages of Object Tracking in Zones (TrackZone) + +- **Targeted Analysis:** Tracking objects within specific zones allows for more focused insights, enabling precise monitoring and analysis of areas of interest, such as entry points or restricted zones. +- **Improved Efficiency:** By narrowing the tracking scope to defined zones, TrackZone reduces computational overhead, ensuring faster processing and optimal performance. +- **Enhanced Security:** Zonal tracking improves surveillance by monitoring critical areas, aiding in the early detection of unusual activity or security breaches. +- **Scalable Solutions:** The ability to focus on specific zones makes TrackZone adaptable to various scenarios, from retail spaces to industrial settings, ensuring seamless integration and scalability. + +## Real World Applications + +| Agriculture | Transportation | +| :-----------------------------------------------------------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| ![Plants Tracking in Field Using Ultralytics YOLO11](https://github.com/ultralytics/docs/releases/download/0/plants-tracking-in-zone-using-ultralytics-yolo11.avif) | ![Vehicles Tracking on Road using Ultralytics YOLO11](https://github.com/ultralytics/docs/releases/download/0/vehicle-tracking-in-zone-using-ultralytics-yolo11.avif) | +| Plants Tracking in Field Using Ultralytics YOLO11 | Vehicles Tracking on Road using Ultralytics YOLO11 | + +!!! example "TrackZone using YOLO11 Example" + + === "CLI" + + ```bash + # Run a trackzone example + yolo solutions trackzone show=True + + # Pass a source video + yolo solutions trackzone show=True source="path/to/video/file.mp4" + + # Pass region coordinates + yolo solutions trackzone show=True region=[(150, 150), (1130, 150), (1130, 570), (150, 570)] + ``` + + === "Python" + + ```python + import cv2 + + from ultralytics import solutions + + cap = cv2.VideoCapture("path/to/video/file.mp4") + assert cap.isOpened(), "Error reading video file" + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + + # Define region points + region_points = [(150, 150), (1130, 150), (1130, 570), (150, 570)] + + # Video writer + video_writer = cv2.VideoWriter("object_counting_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) + + # Init TrackZone (Object Tracking in Zones, not complete frame) + trackzone = solutions.TrackZone( + show=True, # Display the output + region=region_points, # Pass region points + model="yolo11n.pt", # You can use any model that Ultralytics support, i.e. YOLOv9, YOLOv10 + # line_width=2, # Adjust the line width for bounding boxes and text display + # classes=[0, 2], # If you want to count specific classes i.e. person and car with COCO pretrained model. + ) + + # Process video + while cap.isOpened(): + success, im0 = cap.read() + if not success: + print("Video frame is empty or video processing has been successfully completed.") + break + im0 = trackzone.trackzone(im0) + video_writer.write(im0) + + cap.release() + video_writer.release() + cv2.destroyAllWindows() + ``` + +### Argument `TrackZone` + +Here's a table with the `TrackZone` arguments: + +| Name | Type | Default | Description | +| ------------ | ------ | ---------------------------------------------------- | ---------------------------------------------------- | +| `model` | `str` | `None` | Path to Ultralytics YOLO Model File | +| `region` | `list` | `[(150, 150), (1130, 150), (1130, 570), (150, 570)]` | List of points defining the object tracking region. | +| `line_width` | `int` | `2` | Line thickness for bounding boxes. | +| `show` | `bool` | `False` | Flag to control whether to display the video stream. | + +### Arguments `model.track` + +{% include "macros/track-args.md" %} + +## FAQ + +### How do I track objects in a specific area or zone of a video frame using Ultralytics YOLO11? + +Tracking objects in a defined area or zone of a video frame is straightforward with Ultralytics YOLO11. Simply use the command provided below to initiate tracking. This approach ensures efficient analysis and accurate results, making it ideal for applications like surveillance, crowd management, or any scenario requiring zonal tracking. + +```bash +yolo solutions trackzone source="path/to/video/file.mp4" show=True +``` + +### How can I use TrackZone in Python with Ultralytics YOLO11? + +With just a few lines of code, you can set up object tracking in specific zones, making it easy to integrate into your projects. + +```python +import cv2 + +from ultralytics import solutions + +cap = cv2.VideoCapture("path/to/video/file.mp4") +assert cap.isOpened(), "Error reading video file" +w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) + +# Define region points +region_points = [(150, 150), (1130, 150), (1130, 570), (150, 570)] + +# Video writer +video_writer = cv2.VideoWriter("object_counting_output.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) + +# Init TrackZone (Object Tracking in Zones, not complete frame) +trackzone = solutions.TrackZone( + show=True, # Display the output + region=region_points, # Pass region points + model="yolo11n.pt", +) + +# Process video +while cap.isOpened(): + success, im0 = cap.read() + if not success: + print("Video frame is empty or video processing has been successfully completed.") + break + im0 = trackzone.trackzone(im0) + video_writer.write(im0) + +cap.release() +video_writer.release() +cv2.destroyAllWindows() +``` + +### How do I configure the zone points for video processing using Ultralytics TrackZone? + +Configuring zone points for video processing with Ultralytics TrackZone is simple and customizable. You can directly define and adjust the zones through a Python script, allowing precise control over the areas you want to monitor. + +```python +# Define region points +region_points = [(150, 150), (1130, 150), (1130, 570), (150, 570)] + +# Init TrackZone (Object Tracking in Zones, not complete frame) +trackzone = solutions.TrackZone( + show=True, # Display the output + region=region_points, # Pass region points +) +``` diff --git a/docs/en/macros/solutions-args.md b/docs/en/macros/solutions-args.md new file mode 100644 index 0000000000..0ce5d52a05 --- /dev/null +++ b/docs/en/macros/solutions-args.md @@ -0,0 +1,11 @@ +| Argument | Type | Default | Description | +| ---------------- | -------------- | -------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `region` | `list` | `[(20, 400), (1080, 400), (1080, 360), (20, 360)]` | Defines the region points for object counting, queue monitoring, trackzone or speed estimation. The points are defined as coordinates forming a polygonal area for analysis. | +| `show_in` | `bool` | `True` | Indicates whether to display objects that are counted as entering the defined region. Essential for real-world analytics, such as monitoring ingress trends. | +| `show_out` | `bool` | `True` | Indicates whether to display objects that are counted as exiting the defined region. Useful for applications requiring egress tracking and analytics. | +| `colormap` | `int or tuple` | `COLORMAP_PARULA` | Specifies the OpenCV-supported colormap for heatmap visualization. Default is `COLORMAP_PARULA`, but other colormaps can be used for different visualization preferences. | +| `up_angle` | `float` | `145.0` | Angle threshold for detecting the "up" position in workouts monitoring. Can be adjusted based on the position of keypoints for different exercises. | +| `down_angle` | `float` | `90.0` | Angle threshold for detecting the "down" position in workouts monitoring. Adjust this based on keypoint positions for specific exercises. | +| `kpts` | `list` | `[6, 8, 10]` | List of keypoints used for monitoring workouts. These keypoints correspond to body joints or parts, such as shoulders, elbows, and wrists, for exercises like push-ups, pull-ups, squats, ab-workouts. | +| `analytics_type` | `str` | `line` | Specifies the type of analytics visualization to generate. Options include `"line"`, `"pie"`, `"bar"`, or `"area"`. The default is `"line"` for trend visualization. | +| `json_file` | `str` | `None` | Path to the JSON file defining regions for parking systems or similar applications. Enables flexible configuration of analysis areas. | diff --git a/docs/en/reference/solutions/trackzone.md b/docs/en/reference/solutions/trackzone.md new file mode 100644 index 0000000000..546d61dcb0 --- /dev/null +++ b/docs/en/reference/solutions/trackzone.md @@ -0,0 +1,16 @@ +--- +description: Discover Ultralytics' TrackZone solution for real-time object tracking within defined zones. Gain insights into initializing regions, tracking objects exclusively within specific areas, and optimizing video stream processing for region-based object detection. +keywords: Ultralytics, TrackZone, Object Tracking, Zone Tracking, Region Tracking, Python, Real-time Object Tracking, Video Stream Processing, Region-based Detection +--- + +# Reference for `ultralytics/solutions/trackzone.py` + +!!! note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/trackzone.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/trackzone.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/trackzone.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.solutions.trackzone.TrackZone + +

diff --git a/docs/en/solutions/index.md b/docs/en/solutions/index.md index 4046975de7..243fbf2757 100644 --- a/docs/en/solutions/index.md +++ b/docs/en/solutions/index.md @@ -27,8 +27,10 @@ Here's our curated list of Ultralytics solutions that can be used to create awes - [Distance Calculation](../guides/distance-calculation.md) 🚀: Calculate distances between objects using [bounding box](https://www.ultralytics.com/glossary/bounding-box) centroids in YOLO11, essential for spatial analysis. - [Queue Management](../guides/queue-management.md) 🚀: Implement efficient queue management systems to minimize wait times and improve productivity using YOLO11. - [Parking Management](../guides/parking-management.md) 🚀: Organize and direct vehicle flow in parking areas with YOLO11, optimizing space utilization and user experience. -- [Analytics](../guides/analytics.md) 📊 NEW: Conduct comprehensive data analysis to discover patterns and make informed decisions, leveraging YOLO11 for descriptive, predictive, and prescriptive analytics. +- [Analytics](../guides/analytics.md) 📊: Conduct comprehensive data analysis to discover patterns and make informed decisions, leveraging YOLO11 for descriptive, predictive, and prescriptive analytics. - [Live Inference with Streamlit](../guides/streamlit-live-inference.md) 🚀: Leverage the power of YOLO11 for real-time [object detection](https://www.ultralytics.com/glossary/object-detection) directly through your web browser with a user-friendly Streamlit interface. +- [Live Inference with Streamlit](../guides/streamlit-live-inference.md) 🚀: Leverage the power of YOLO11 for real-time [object detection](https://www.ultralytics.com/glossary/object-detection) directly through your web browser with a user-friendly Streamlit interface. +- [Track Objects in Zone](../guides/trackzone.md) 🎯 NEW: Learn how to track objects within specific zones of video frames using YOLO11 for precise and efficient monitoring. ## Solutions Usage diff --git a/docs/en/usage/cfg.md b/docs/en/usage/cfg.md index 95dc8b46f2..c51863c5e0 100644 --- a/docs/en/usage/cfg.md +++ b/docs/en/usage/cfg.md @@ -130,6 +130,14 @@ It is crucial to thoughtfully configure these settings to ensure the exported mo [Export Guide](../modes/export.md){ .md-button } +## Solutions Settings + +The configuration settings for Ultralytics Solutions offer a flexible way to customize the model for various tasks like object counting, heatmap creation, workout tracking, data analysis, zone tracking, queue management, and region-based counting. These options make it easy to adjust the setup for accurate and useful results tailored to specific needs. + +{% include "macros/solutions-args.md" %} + +[Solutions Guide](../solutions/index.md){ .md-button } + ## Augmentation Settings Augmentation techniques are essential for improving the robustness and performance of YOLO models by introducing variability into the [training data](https://www.ultralytics.com/glossary/training-data), helping the model generalize better to unseen data. The following table outlines the purpose and effect of each augmentation argument: diff --git a/mkdocs.yml b/mkdocs.yml index 0d3320a230..283d52f3d9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -326,7 +326,8 @@ nav: - Distance Calculation: guides/distance-calculation.md - Queue Management: guides/queue-management.md - Parking Management: guides/parking-management.md - - Live Inference 🚀 NEW: guides/streamlit-live-inference.md + - Live Inference: guides/streamlit-live-inference.md + - Track Objects in Zone 🚀 NEW: guides/trackzone.md - Guides: - guides/index.md - YOLO Common Issues: guides/yolo-common-issues.md @@ -573,6 +574,7 @@ nav: - speed_estimation: reference/solutions/speed_estimation.md - streamlit_inference: reference/solutions/streamlit_inference.md - region_counter: reference/solutions/region_counter.md + - trackzone: reference/solutions/trackzone.md - trackers: - basetrack: reference/trackers/basetrack.md - bot_sort: reference/trackers/bot_sort.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 7b73169deb..601d1bb363 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.39" +__version__ = "8.3.40" import os diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 5f0222a71f..e4c239f3d4 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -41,6 +41,7 @@ SOLUTION_MAP = { "speed": ("SpeedEstimator", "estimate_speed"), "workout": ("AIGym", "monitor"), "analytics": ("Analytics", "process_data"), + "trackzone": ("TrackZone", "trackzone"), "help": None, } @@ -74,13 +75,12 @@ ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] SOLUTIONS_HELP_MSG = f""" Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo solutions' usage overview: - yolo SOLUTIONS SOLUTION ARGS - - Where SOLUTIONS (required) is a keyword - SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())} - ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults. - See all ARGS at https://docs.ultralytics.com/usage/cfg or with 'yolo cfg' + yolo solutions SOLUTION ARGS + Where SOLUTION (optional) is one of {list(SOLUTION_MAP.keys())} + ARGS (optional) are any number of custom 'arg=value' pairs like 'show_in=True' that override defaults + at https://docs.ultralytics.com/usage/cfg + 1. Call object counting solution yolo solutions count source="path/to/video/file.mp4" region=[(20, 400), (1080, 400), (1080, 360), (20, 360)] @@ -95,6 +95,9 @@ SOLUTIONS_HELP_MSG = f""" 5. Generate analytical graphs yolo solutions analytics analytics_type="pie" + + 6. Track Objects Within Specific Zones + yolo solutions trackzone source="path/to/video/file.mp4" region=[(150, 150), (1130, 150), (1130, 570), (150, 570)] """ CLI_HELP_MSG = f""" Arguments received: {str(['yolo'] + ARGV[1:])}. Ultralytics 'yolo' commands use the following syntax: diff --git a/ultralytics/solutions/__init__.py b/ultralytics/solutions/__init__.py index 9de61edce2..a2333129bb 100644 --- a/ultralytics/solutions/__init__.py +++ b/ultralytics/solutions/__init__.py @@ -10,6 +10,7 @@ from .queue_management import QueueManager from .region_counter import RegionCounter from .speed_estimation import SpeedEstimator from .streamlit_inference import inference +from .trackzone import TrackZone __all__ = ( "AIGym", @@ -23,4 +24,5 @@ __all__ = ( "Analytics", "inference", "RegionCounter", + "TrackZone", ) diff --git a/ultralytics/solutions/trackzone.py b/ultralytics/solutions/trackzone.py new file mode 100644 index 0000000000..0492c0a7e5 --- /dev/null +++ b/ultralytics/solutions/trackzone.py @@ -0,0 +1,68 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +import cv2 +import numpy as np + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator, colors + + +class TrackZone(BaseSolution): + """ + A class to manage region-based object tracking in a video stream. + + This class extends the BaseSolution class and provides functionality for tracking objects within a specific region + defined by a polygonal area. Objects outside the region are excluded from tracking. It supports dynamic initialization + of the region, allowing either a default region or a user-specified polygon. + + Attributes: + region (ndarray): The polygonal region for tracking, represented as a convex hull. + + Methods: + trackzone: Processes each frame of the video, applying region-based tracking. + + Examples: + >>> tracker = TrackZone() + >>> frame = cv2.imread("frame.jpg") + >>> processed_frame = tracker.trackzone(frame) + >>> cv2.imshow("Tracked Frame", processed_frame) + """ + + def __init__(self, **kwargs): + """Initializes the TrackZone class for tracking objects within a defined region in video streams.""" + super().__init__(**kwargs) + default_region = [(150, 150), (1130, 150), (1130, 570), (150, 570)] + self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32)) + + def trackzone(self, im0): + """ + Processes the input frame to track objects within a defined region. + + This method initializes the annotator, creates a mask for the specified region, extracts tracks + only from the masked area, and updates tracking information. Objects outside the region are ignored. + + Args: + im0 (numpy.ndarray): The input image or frame to be processed. + + Returns: + (numpy.ndarray): The processed image with tracking id and bounding boxes annotations. + + Examples: + >>> tracker = TrackZone() + >>> frame = cv2.imread("path/to/image.jpg") + >>> tracker.trackzone(frame) + """ + self.annotator = Annotator(im0, line_width=self.line_width) # Initialize annotator + # Create a mask for the region and extract tracks from the masked image + masked_frame = cv2.bitwise_and(im0, im0, mask=cv2.fillPoly(np.zeros_like(im0[:, :, 0]), [self.region], 255)) + self.extract_tracks(masked_frame) + + cv2.polylines(im0, [self.region], isClosed=True, color=(255, 255, 255), thickness=self.line_width * 2) + + # Iterate over boxes, track ids, classes indexes list and draw bounding boxes + for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss): + self.annotator.box_label(box, label=f"{self.names[cls]}:{track_id}", color=colors(track_id, True)) + + self.display_output(im0) # display output with base class function + + return im0 # return output image for more usage