diff --git a/docs/en/guides/nvidia-jetson.md b/docs/en/guides/nvidia-jetson.md index 8a43d978b1..12e46bff61 100644 --- a/docs/en/guides/nvidia-jetson.md +++ b/docs/en/guides/nvidia-jetson.md @@ -54,7 +54,7 @@ The first step after getting your hands on an NVIDIA Jetson device is to flash N 1. If you own an official NVIDIA Development Kit such as the Jetson Orin Nano Developer Kit, you can [download an image and prepare an SD card with JetPack for booting the device](https://developer.nvidia.com/embedded/learn/get-started-jetson-orin-nano-devkit). 2. If you own any other NVIDIA Development Kit, you can [flash JetPack to the device using SDK Manager](https://docs.nvidia.com/sdk-manager/install-with-sdkm-jetson/index.html). -3. If you own a Seeed Studio reComputer J4012 device, you can [flash JetPack to the included SSD](https://wiki.seeedstudio.com/reComputer_J4012_Flash_Jetpack/) and if you own a Seeed Studio reComputer J1020 v2 device, you can [flash JetPack to the eMMC/ SSD](https://wiki.seeedstudio.com/reComputer_J2021_J202_Flash_Jetpack/). +3. If you own a Seeed Studio reComputer J4012 device, you can [flash JetPack to the included SSD](https://wiki.seeedstudio.com/recomputer_j4012_flash_jetpack/) and if you own a Seeed Studio reComputer J1020 v2 device, you can [flash JetPack to the eMMC/ SSD](https://wiki.seeedstudio.com/reComputer_J2021_J202_Flash_Jetpack/). 4. If you own any other third party device powered by the NVIDIA Jetson module, it is recommended to follow [command-line flashing](https://docs.nvidia.com/jetson/archives/r35.5.0/DeveloperGuide/IN/QuickStart.html). !!! note diff --git a/docs/en/guides/region-counting.md b/docs/en/guides/region-counting.md index a27c2b4e53..c8363d68d3 100644 --- a/docs/en/guides/region-counting.md +++ b/docs/en/guides/region-counting.md @@ -34,56 +34,65 @@ keywords: object counting, regions, YOLOv8, computer vision, Ultralytics, effici | ![People Counting in Different Region using Ultralytics YOLOv8](https://github.com/ultralytics/docs/releases/download/0/people-counting-different-region-ultralytics-yolov8.avif) | ![Crowd Counting in Different Region using Ultralytics YOLOv8](https://github.com/ultralytics/docs/releases/download/0/crowd-counting-different-region-ultralytics-yolov8.avif) | | People Counting in Different Region using Ultralytics YOLOv8 | Crowd Counting in Different Region using Ultralytics YOLOv8 | -## Steps to Run +!!! example "Region Counting Example" -### Step 1: Install Required Libraries + === "Python" -Begin by cloning the Ultralytics repository, installing dependencies, and navigating to the local directory using the provided commands in Step 2. + ```python + import cv2 + from ultralytics import solutions -```bash -# Clone Ultralytics repo -git clone https://github.com/ultralytics/ultralytics + cap = cv2.VideoCapture("Path/to/video/file.mp4") + assert cap.isOpened(), "Error reading video file" + w, h, fps = (int(cap.get(x)) for x in (cv2.CAP_PROP_FRAME_WIDTH, cv2.CAP_PROP_FRAME_HEIGHT, cv2.CAP_PROP_FPS)) -# Navigate to the local directory -cd ultralytics/examples/YOLOv8-Region-Counter -``` + # Define region points + # region_points = [(20, 400), (1080, 404), (1080, 360), (20, 360)] # Pass region as list -### Step 2: Run Region Counting Using Ultralytics YOLOv8 + # pass region as dictionary + region_points = { + "region-01": [(50, 50), (250, 50), (250, 250), (50, 250)], + "region-02": [(640, 640), (780, 640), (780, 720), (640, 720)] + } -Execute the following basic commands for inference. + # Video writer + video_writer = cv2.VideoWriter("region_counting.avi", cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)) -???+ tip "Region is Movable" + # Init Object Counter + region = solutions.RegionCounter( + show=True, + region=region_points, + model="yolo11n.pt", + ) - During video playback, you can interactively move the region within the video by clicking and dragging using the left mouse button. + # Process video + while cap.isOpened(): + success, im0 = cap.read() + if not success: + print("Video frame is empty or video processing has been successfully completed.") + break + im0 = region.count(im0) + video_writer.write(im0) -```bash -# Save results -python yolov8_region_counter.py --source "path/to/video.mp4" --save-img + cap.release() + video_writer.release() + cv2.destroyAllWindows() + ``` -# Run model on CPU -python yolov8_region_counter.py --source "path/to/video.mp4" --device cpu +!!! tip "Ultralytics Example Code" -# Change model file -python yolov8_region_counter.py --source "path/to/video.mp4" --weights "path/to/model.pt" + The Ultralytics region counting module is available in our [examples section](https://github.com/ultralytics/ultralytics/blob/main/examples/YOLOv8-Region-Counter/yolov8_region_counter.py). You can explore this example for code customization and modify it to suit your specific use case. -# Detect specific classes (e.g., first and third classes) -python yolov8_region_counter.py --source "path/to/video.mp4" --classes 0 2 +### Argument `RegionCounter` -# View results without saving -python yolov8_region_counter.py --source "path/to/video.mp4" --view-img -``` +Here's a table with the `RegionCounter` arguments: -### Optional Arguments - -| Name | Type | Default | Description | -| -------------------- | ------ | ------------ | --------------------------------------------------------------------------- | -| `--source` | `str` | `None` | Path to video file, for webcam 0 | -| `--line_thickness` | `int` | `2` | [Bounding Box](https://www.ultralytics.com/glossary/bounding-box) thickness | -| `--save-img` | `bool` | `False` | Save the predicted video/image | -| `--weights` | `str` | `yolov8n.pt` | Weights file path | -| `--classes` | `list` | `None` | Detect specific classes i.e. --classes 0 2 | -| `--region-thickness` | `int` | `2` | Region Box thickness | -| `--track-thickness` | `int` | `2` | Tracking line thickness | +| Name | Type | Default | Description | +| ------------ | ------ | -------------------------- | ---------------------------------------------------- | +| `model` | `str` | `None` | Path to Ultralytics YOLO Model File | +| `region` | `list` | `[(20, 400), (1260, 400)]` | List of points defining the counting region. | +| `line_width` | `int` | `2` | Line thickness for bounding boxes. | +| `show` | `bool` | `False` | Flag to control whether to display the video stream. | ## FAQ @@ -107,7 +116,7 @@ Follow these steps to run object counting in Ultralytics YOLOv8: python yolov8_region_counter.py --source "path/to/video.mp4" --save-img ``` -For more options, visit the [Run Region Counting](#steps-to-run) section. +For more options, visit the [Run Region Counting](https://github.com/ultralytics/ultralytics/blob/main/examples/YOLOv8-Region-Counter/readme.md) section. ### Why should I use Ultralytics YOLOv8 for object counting in regions? @@ -121,7 +130,7 @@ Explore deeper benefits in the [Advantages](#advantages-of-object-counting-in-re ### Can the defined regions be adjusted during video playback? -Yes, with Ultralytics YOLOv8, regions can be interactively moved during video playback. Simply click and drag with the left mouse button to reposition the region. This feature enhances flexibility for dynamic environments. Learn more in the tip section for [movable regions](#step-2-run-region-counting-using-ultralytics-yolov8). +Yes, with Ultralytics YOLOv8, regions can be interactively moved during video playback. Simply click and drag with the left mouse button to reposition the region. This feature enhances flexibility for dynamic environments. Learn more in the tip section for [movable regions](https://github.com/ultralytics/ultralytics/blob/33cdaa5782efb2bc2b5ede945771ba647882830d/examples/YOLOv8-Region-Counter/yolov8_region_counter.py#L39). ### What are some real-world applications of object counting in regions? diff --git a/docs/en/hub/models.md b/docs/en/hub/models.md index db098669ac..1533a19ff2 100644 --- a/docs/en/hub/models.md +++ b/docs/en/hub/models.md @@ -1,7 +1,7 @@ --- comments: true -description: Explore Ultralytics HUB for easy training, analysis, preview, deployment and sharing of custom vision AI models using YOLOv8. Start training today!. -keywords: Ultralytics HUB, YOLOv8, custom AI models, model training, model deployment, model analysis, vision AI +description: Explore Ultralytics HUB for easy training, analysis, preview, deployment and sharing of custom vision AI models using YOLO11. Start training today!. +keywords: Ultralytics HUB, YOLO11, custom AI models, model training, model deployment, model analysis, vision AI --- # Ultralytics HUB Models @@ -66,7 +66,7 @@ In this step, you have to choose the project in which you want to create your mo !!! info - You can read more about the available [YOLOv8](https://docs.ultralytics.com/models/yolov8/) (and [YOLOv5](https://docs.ultralytics.com/models/yolov5/)) architectures in our documentation. + You can read more about the available [YOLO models](https://docs.ultralytics.com/models) and architectures in our documentation. By default, your model will use a pre-trained model (trained on the [COCO](https://docs.ultralytics.com/datasets/detect/coco/) dataset) to reduce training time. You can change this behavior and tweak your model's configuration by opening the **Advanced Model Configuration** accordion. diff --git a/docs/en/integrations/sony-imx500.md b/docs/en/integrations/sony-imx500.md index 335daf51fc..88338ebb67 100644 --- a/docs/en/integrations/sony-imx500.md +++ b/docs/en/integrations/sony-imx500.md @@ -29,7 +29,7 @@ The IMX500 works with quantized models. Quantization makes models smaller and fa **IMX500 Key Features:** -- **Metadata Output:** Instead of transmitting full images, the IMX500 outputs only metadata, minimizing data size, reducing bandwidth, and lowering costs. +- **Metadata Output:** Instead of transmitting images only, the IMX500 can output both image and metadata (inference result), and can output metadata only for minimizing data size, reducing bandwidth, and lowering costs. - **Addresses Privacy Concerns:** By processing data on the device, the IMX500 addresses privacy concerns, ideal for human-centric applications like person counting and occupancy tracking. - **Real-time Processing:** Fast, on-sensor processing supports real-time decisions, perfect for edge AI applications such as autonomous systems. @@ -247,7 +247,7 @@ Export to IMX500 format has wide applicability across industries. Here are some ## Conclusion -Exporting Ultralytics YOLOv8 models to Sony's IMX500 format allows you to deploy your models for efficient inference on IMX500-based cameras. By leveraging advanced quantization and pruning techniques, you can reduce model size and improve inference speed without significantly compromising accuracy. +Exporting Ultralytics YOLOv8 models to Sony's IMX500 format allows you to deploy your models for efficient inference on IMX500-based cameras. By leveraging advanced quantization techniques, you can reduce model size and improve inference speed without significantly compromising accuracy. For more information and detailed guidelines, refer to Sony's [IMX500 website](https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera). @@ -271,7 +271,7 @@ The export process will create a directory containing the necessary files for de The IMX500 format offers several important advantages for edge deployment: - On-chip AI processing reduces latency and power consumption -- Outputs metadata instead of full images, minimizing bandwidth usage +- Outputs both image and metadata (inference result) instead of images only - Enhanced privacy by processing data locally without cloud dependency - Real-time processing capabilities ideal for time-sensitive applications - Optimized quantization for efficient model deployment on resource-constrained devices diff --git a/docs/en/macros/train-args.md b/docs/en/macros/train-args.md index ede32f910b..0bc48f8117 100644 --- a/docs/en/macros/train-args.md +++ b/docs/en/macros/train-args.md @@ -17,7 +17,6 @@ | `exist_ok` | `False` | If True, allows overwriting of an existing project/name directory. Useful for iterative experimentation without needing to manually clear previous outputs. | | `pretrained` | `True` | Determines whether to start training from a pretrained model. Can be a boolean value or a string path to a specific model from which to load weights. Enhances training efficiency and model performance. | | `optimizer` | `'auto'` | Choice of optimizer for training. Options include `SGD`, `Adam`, `AdamW`, `NAdam`, `RAdam`, `RMSProp` etc., or `auto` for automatic selection based on model configuration. Affects convergence speed and stability. | -| `verbose` | `False` | Enables verbose output during training, providing detailed logs and progress updates. Useful for debugging and closely monitoring the training process. | | `seed` | `0` | Sets the random seed for training, ensuring reproducibility of results across runs with the same configurations. | | `deterministic` | `True` | Forces deterministic algorithm use, ensuring reproducibility but may affect performance and speed due to the restriction on non-deterministic algorithms. | | `single_cls` | `False` | Treats all classes in multi-class datasets as a single class during training. Useful for binary classification tasks or when focusing on object presence rather than classification. | diff --git a/docs/en/reference/solutions/region_counter.md b/docs/en/reference/solutions/region_counter.md new file mode 100644 index 0000000000..0f27adff28 --- /dev/null +++ b/docs/en/reference/solutions/region_counter.md @@ -0,0 +1,16 @@ +--- +description: Explore the Ultralytics Object Counter for real-time video streams. Learn about initializing parameters, tracking objects, and more. +keywords: Ultralytics, Object Counter, Real-time Tracking, Video Stream, Python, Object Detection +--- + +# Reference for `ultralytics/solutions/region_counter.py` + +!!! note + + This file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/region_counter.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/solutions/region_counter.py). If you spot a problem please help fix it by [contributing](https://docs.ultralytics.com/help/contributing/) a [Pull Request](https://github.com/ultralytics/ultralytics/edit/main/ultralytics/solutions/region_counter.py) 🛠️. Thank you 🙏! + +
+ +## ::: ultralytics.solutions.region_counter.RegionCounter + +

diff --git a/docs/mkdocs_github_authors.yaml b/docs/mkdocs_github_authors.yaml index 49360cf687..3e650937fe 100644 --- a/docs/mkdocs_github_authors.yaml +++ b/docs/mkdocs_github_authors.yaml @@ -10,6 +10,9 @@ 130829914+IvorZhu331@users.noreply.github.com: avatar: https://avatars.githubusercontent.com/u/130829914?v=4 username: IvorZhu331 +131249114+ServiAmirPM@users.noreply.github.com: + avatar: https://avatars.githubusercontent.com/u/131249114?v=4 + username: ServiAmirPM 131261051+MatthewNoyce@users.noreply.github.com: avatar: https://avatars.githubusercontent.com/u/131261051?v=4 username: MatthewNoyce diff --git a/mkdocs.yml b/mkdocs.yml index 04d734430d..33bbbc3036 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -571,6 +571,7 @@ nav: - solutions: reference/solutions/solutions.md - speed_estimation: reference/solutions/speed_estimation.md - streamlit_inference: reference/solutions/streamlit_inference.md + - region_counter: reference/solutions/region_counter.md - trackers: - basetrack: reference/trackers/basetrack.md - bot_sort: reference/trackers/bot_sort.md diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 2ff53681a3..ac2c604fd7 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.29" +__version__ = "8.3.31" import os diff --git a/ultralytics/cfg/__init__.py b/ultralytics/cfg/__init__.py index 6de1b9484c..66d5708900 100644 --- a/ultralytics/cfg/__init__.py +++ b/ultralytics/cfg/__init__.py @@ -679,6 +679,9 @@ def handle_yolo_solutions(args: List[str]) -> None: ) s_n = "count" # Default solution if none provided + if args and args[0] == "help": # Add check for return if user call `yolo solutions help` + return + cls, method = SOLUTION_MAP[s_n] # solution class name, method name and default source from ultralytics import solutions # import ultralytics solutions diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index 029e4471e0..8de0a2e6a1 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -750,7 +750,7 @@ class Results(SimpleClass): save_one_box( d.xyxy, self.orig_img.copy(), - file=Path(save_dir) / self.names[int(d.cls)] / f"{Path(file_name)}.jpg", + file=Path(save_dir) / self.names[int(d.cls)] / Path(file_name).with_suffix(".jpg"), BGR=True, ) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index a7e3922b4f..c088111fda 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -279,12 +279,7 @@ class BaseTrainer: # Batch size if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size - self.args.batch = self.batch_size = check_train_batch_size( - model=self.model, - imgsz=self.args.imgsz, - amp=self.amp, - batch=self.batch_size, - ) + self.args.batch = self.batch_size = self.auto_batch() # Dataloaders batch_size = self.batch_size // max(world_size, 1) @@ -478,6 +473,16 @@ class BaseTrainer: self._clear_memory() self.run_callbacks("teardown") + def auto_batch(self, max_num_obj=0): + """Get batch size by calculating memory occupation of model.""" + return check_train_batch_size( + model=self.model, + imgsz=self.args.imgsz, + amp=self.amp, + batch=self.batch_size, + max_num_obj=max_num_obj, + ) # returns batch size + def _get_memory(self): """Get accelerator memory utilization in GB.""" if self.device.type == "mps": diff --git a/ultralytics/models/yolo/detect/train.py b/ultralytics/models/yolo/detect/train.py index e0dbb367f7..606b9fb92b 100644 --- a/ultralytics/models/yolo/detect/train.py +++ b/ultralytics/models/yolo/detect/train.py @@ -141,3 +141,10 @@ class DetectionTrainer(BaseTrainer): boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0) cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0) plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot) + + def auto_batch(self): + """Get batch size by calculating memory occupation of model.""" + train_dataset = self.build_dataset(self.trainset, mode="train", batch=16) + # 4 for mosaic augmentation + max_num_obj = max(len(l["cls"]) for l in train_dataset.labels) * 4 + return super().auto_batch(max_num_obj) diff --git a/ultralytics/solutions/__init__.py b/ultralytics/solutions/__init__.py index 4446c1826e..9de61edce2 100644 --- a/ultralytics/solutions/__init__.py +++ b/ultralytics/solutions/__init__.py @@ -7,6 +7,7 @@ from .heatmap import Heatmap from .object_counter import ObjectCounter from .parking_management import ParkingManagement, ParkingPtsSelection from .queue_management import QueueManager +from .region_counter import RegionCounter from .speed_estimation import SpeedEstimator from .streamlit_inference import inference @@ -21,4 +22,5 @@ __all__ = ( "SpeedEstimator", "Analytics", "inference", + "RegionCounter", ) diff --git a/ultralytics/solutions/analytics.py b/ultralytics/solutions/analytics.py index aed7beed94..9be192448c 100644 --- a/ultralytics/solutions/analytics.py +++ b/ultralytics/solutions/analytics.py @@ -54,7 +54,7 @@ class Analytics(BaseSolution): self.y_label = "Total Counts" # Predefined data - self.bg_color = "#00F344" # background color of frame + self.bg_color = "#F3F3F3" # background color of frame self.fg_color = "#111E68" # foreground color of frame self.title = "Ultralytics Solutions" # window name self.max_points = 45 # maximum points to be drawn on window diff --git a/ultralytics/solutions/region_counter.py b/ultralytics/solutions/region_counter.py new file mode 100644 index 0000000000..03575100d5 --- /dev/null +++ b/ultralytics/solutions/region_counter.py @@ -0,0 +1,112 @@ +# Ultralytics YOLO 🚀, AGPL-3.0 license + +from ultralytics.solutions.solutions import BaseSolution +from ultralytics.utils.plotting import Annotator, colors + + +class RegionCounter(BaseSolution): + """ + A class designed for real-time counting of objects within user-defined regions in a video stream. + + This class inherits from `BaseSolution` and offers functionalities to define polygonal regions in a video + frame, track objects, and count those objects that pass through each defined region. This makes it useful + for applications that require counting in specified areas, such as monitoring zones or segmented sections. + + Attributes: + region_template (dict): A template for creating new counting regions with default attributes including + the name, polygon coordinates, and display colors. + counting_regions (list): A list storing all defined regions, where each entry is based on `region_template` + and includes specific region settings like name, coordinates, and color. + + Methods: + add_region: Adds a new counting region with specified attributes, such as the region's name, polygon points, + region color, and text color. + count: Processes video frames to count objects in each region, drawing regions and displaying counts + on the frame. Handles object detection, region definition, and containment checks. + """ + + def __init__(self, **kwargs): + """Initializes the RegionCounter class for real-time counting in different regions of the video streams.""" + super().__init__(**kwargs) + self.region_template = { + "name": "Default Region", + "polygon": None, + "counts": 0, + "dragging": False, + "region_color": (255, 255, 255), + "text_color": (0, 0, 0), + } + self.counting_regions = [] + + def add_region(self, name, polygon_points, region_color, text_color): + """ + Adds a new region to the counting list based on the provided template with specific attributes. + + Args: + name (str): Name assigned to the new region. + polygon_points (list[tuple]): List of (x, y) coordinates defining the region's polygon. + region_color (tuple): BGR color for region visualization. + text_color (tuple): BGR color for the text within the region. + """ + region = self.region_template.copy() + region.update( + { + "name": name, + "polygon": self.Polygon(polygon_points), + "region_color": region_color, + "text_color": text_color, + } + ) + self.counting_regions.append(region) + + def count(self, im0): + """ + Processes the input frame to detect and count objects within each defined region. + + Args: + im0 (numpy.ndarray): Input image frame where objects and regions are annotated. + + Returns: + im0 (numpy.ndarray): Processed image frame with annotated counting information. + """ + self.annotator = Annotator(im0, line_width=self.line_width) + self.extract_tracks(im0) + + # Region initialization and conversion + if self.region is None: + self.initialize_region() + regions = {"Region#01": self.region} + else: + regions = self.region if isinstance(self.region, dict) else {"Region#01": self.region} + + # Draw regions and process counts for each defined area + for idx, (region_name, reg_pts) in enumerate(regions.items(), start=1): + color = colors(idx, True) + self.annotator.draw_region(reg_pts=reg_pts, color=color, thickness=self.line_width * 2) + self.add_region(region_name, reg_pts, color, self.annotator.get_txt_color()) + + # Prepare regions for containment check + for region in self.counting_regions: + region["prepared_polygon"] = self.prep(region["polygon"]) + + # Process bounding boxes and count objects within each region + for box, cls in zip(self.boxes, self.clss): + self.annotator.box_label(box, label=self.names[cls], color=colors(cls, True)) + bbox_center = ((box[0] + box[2]) / 2, (box[1] + box[3]) / 2) + + for region in self.counting_regions: + if region["prepared_polygon"].contains(self.Point(bbox_center)): + region["counts"] += 1 + + # Display counts in each region + for region in self.counting_regions: + self.annotator.text_label( + region["polygon"].bounds, + label=str(region["counts"]), + color=region["region_color"], + txt_color=region["text_color"], + ) + region["counts"] = 0 # Reset count for next frame + + self.display_output(im0) + return im0 diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py index ea94767033..ba3f1ec2cc 100644 --- a/ultralytics/solutions/solutions.py +++ b/ultralytics/solutions/solutions.py @@ -50,10 +50,12 @@ class BaseSolution: """ check_requirements("shapely>=2.0.0") from shapely.geometry import LineString, Point, Polygon + from shapely.prepared import prep self.LineString = LineString self.Polygon = Polygon self.Point = Point + self.prep = prep # Load config and update with args DEFAULT_SOL_DICT.update(kwargs) diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py index 6a0d9cbc29..0c3e8e4bd0 100644 --- a/ultralytics/utils/autobatch.py +++ b/ultralytics/utils/autobatch.py @@ -11,7 +11,7 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr from ultralytics.utils.torch_utils import autocast, profile -def check_train_batch_size(model, imgsz=640, amp=True, batch=-1): +def check_train_batch_size(model, imgsz=640, amp=True, batch=-1, max_num_obj=1): """ Compute optimal YOLO training batch size using the autobatch() function. @@ -20,6 +20,7 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1): imgsz (int, optional): Image size used for training. amp (bool, optional): Use automatic mixed precision if True. batch (float, optional): Fraction of GPU memory to use. If -1, use default. + max_num_obj (int, optional): The maximum number of objects from dataset. Returns: (int): Optimal batch size computed using the autobatch() function. @@ -29,10 +30,12 @@ def check_train_batch_size(model, imgsz=640, amp=True, batch=-1): Otherwise, a default fraction of 0.6 is used. """ with autocast(enabled=amp): - return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6) + return autobatch( + deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj + ) -def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): +def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch, max_num_obj=1): """ Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory. @@ -41,6 +44,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): imgsz (int, optional): The image size used as input for the YOLO model. Defaults to 640. fraction (float, optional): The fraction of available CUDA memory to use. Defaults to 0.60. batch_size (int, optional): The default batch size to use if an error is detected. Defaults to 16. + max_num_obj (int, optional): The maximum number of objects from dataset. Returns: (int): The optimal batch size. @@ -70,7 +74,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch): batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64] try: img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] - results = profile(img, model, n=1, device=device) + results = profile(img, model, n=1, device=device, max_num_obj=max_num_obj) # Fit a solution y = [x[2] for x in results if x] # memory [2] diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py index 9fb5020923..eec2a3b2d2 100644 --- a/ultralytics/utils/tal.py +++ b/ultralytics/utils/tal.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +from . import LOGGER from .checks import check_version from .metrics import bbox_iou, probiou from .ops import xywhr2xyxyxyxy @@ -58,17 +59,46 @@ class TaskAlignedAssigner(nn.Module): """ self.bs = pd_scores.shape[0] self.n_max_boxes = gt_bboxes.shape[1] + device = gt_bboxes.device if self.n_max_boxes == 0: - device = gt_bboxes.device return ( - torch.full_like(pd_scores[..., 0], self.bg_idx).to(device), - torch.zeros_like(pd_bboxes).to(device), - torch.zeros_like(pd_scores).to(device), - torch.zeros_like(pd_scores[..., 0]).to(device), - torch.zeros_like(pd_scores[..., 0]).to(device), + torch.full_like(pd_scores[..., 0], self.bg_idx), + torch.zeros_like(pd_bboxes), + torch.zeros_like(pd_scores), + torch.zeros_like(pd_scores[..., 0]), + torch.zeros_like(pd_scores[..., 0]), ) + try: + return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt) + except torch.OutOfMemoryError: + # Move tensors to CPU, compute, then move back to original device + LOGGER.warning("WARNING: CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU") + cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)] + result = self._forward(*cpu_tensors) + return tuple(t.to(device) for t in result) + + def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): + """ + Compute the task-aligned assignment. Reference code is available at + https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py. + + Args: + pd_scores (Tensor): shape(bs, num_total_anchors, num_classes) + pd_bboxes (Tensor): shape(bs, num_total_anchors, 4) + anc_points (Tensor): shape(num_total_anchors, 2) + gt_labels (Tensor): shape(bs, n_max_boxes, 1) + gt_bboxes (Tensor): shape(bs, n_max_boxes, 4) + mask_gt (Tensor): shape(bs, n_max_boxes, 1) + + Returns: + target_labels (Tensor): shape(bs, num_total_anchors) + target_bboxes (Tensor): shape(bs, num_total_anchors, 4) + target_scores (Tensor): shape(bs, num_total_anchors, num_classes) + fg_mask (Tensor): shape(bs, num_total_anchors) + target_gt_idx (Tensor): shape(bs, num_total_anchors) + """ mask_pos, align_metric, overlaps = self.get_pos_mask( pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt ) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 966e980f1b..4bacc79adb 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -623,7 +623,7 @@ def convert_optimizer_state_dict_to_fp16(state_dict): return state_dict -def profile(input, ops, n=10, device=None): +def profile(input, ops, n=10, device=None, max_num_obj=0): """ Ultralytics speed, memory and FLOPs profiler. @@ -671,6 +671,14 @@ def profile(input, ops, n=10, device=None): t[2] = float("nan") tf += (t[1] - t[0]) * 1000 / n # ms per op forward tb += (t[2] - t[1]) * 1000 / n # ms per op backward + if max_num_obj: # simulate training with predictions per image grid (for AutoBatch) + torch.randn( + x.shape[0], + max_num_obj, + int(sum([(x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist()])), + device=device, + dtype=torch.float32, + ) mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB) s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters