From 5b76bed7d02f35584d6c3565dbe13e4ac5137588 Mon Sep 17 00:00:00 2001
From: Semih Demirel <85176438+semihhdemirel@users.noreply.github.com>
Date: Tue, 24 Dec 2024 14:18:34 +0300
Subject: [PATCH] [Example] RTDETR-ONNXRuntime-Python (#18369)

---
 examples/README.md                           |   1 +
 examples/RTDETR-ONNXRuntime-Python/README.md |  43 ++++
 examples/RTDETR-ONNXRuntime-Python/main.py   | 213 +++++++++++++++++++
 3 files changed, 257 insertions(+)
 create mode 100644 examples/RTDETR-ONNXRuntime-Python/README.md
 create mode 100644 examples/RTDETR-ONNXRuntime-Python/main.py

diff --git a/examples/README.md b/examples/README.md
index 76f078bde..ee06d337b 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -12,6 +12,7 @@ This directory features a collection of real-world applications and walkthroughs
 | [YOLO .Net ONNX Detection C#](https://www.nuget.org/packages/Yolov8.Net) | C# .Net | [Samuel Stainback](https://github.com/sstainba) |
 | [YOLOv8 on NVIDIA Jetson(TensorRT and DeepStream)](https://wiki.seeedstudio.com/YOLOv8-DeepStream-TRT-Jetson/) | Python | [Lakshantha](https://github.com/lakshanthad) |
 | [YOLOv8 ONNXRuntime Python](./YOLOv8-ONNXRuntime) | Python/ONNXRuntime | [Semih Demirel](https://github.com/semihhdemirel) |
+| [RTDETR ONNXRuntime Python](./RTDETR-ONNXRuntime-Python) | Python/ONNXRuntime | [Semih Demirel](https://github.com/semihhdemirel) |
 | [YOLOv8 ONNXRuntime CPP](./YOLOv8-ONNXRuntime-CPP) | C++/ONNXRuntime | [DennisJcy](https://github.com/DennisJcy), [Onuralp Sezer](https://github.com/onuralpszr) |
 | [RTDETR ONNXRuntime C#](https://github.com/Kayzwer/yolo-cs/blob/master/RTDETR.cs) | C#/ONNX | [Kayzwer](https://github.com/Kayzwer) |
 | [YOLOv8 SAHI Video Inference](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-SAHI-Inference-Video/yolov8_sahi.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) |
diff --git a/examples/RTDETR-ONNXRuntime-Python/README.md b/examples/RTDETR-ONNXRuntime-Python/README.md
new file mode 100644
index 000000000..1861da829
--- /dev/null
+++ b/examples/RTDETR-ONNXRuntime-Python/README.md
@@ -0,0 +1,43 @@
+# RTDETR - ONNX Runtime
+
+This project implements RTDETR object detection inference using ONNX Runtime.
+
+## Installation
+
+To run this project, you need to install the required dependencies. The following instructions will guide you through the installation process.
+
+### Installing Required Dependencies
+
+You can install the required dependencies by running the following command:
+
+```bash
+pip install -r requirements.txt
+```
+
+### Installing `onnxruntime-gpu`
+
+If you have an NVIDIA GPU and want to leverage GPU acceleration, install the `onnxruntime-gpu` package using the following command:
+
+```bash
+pip install onnxruntime-gpu
+```
+
+Note: Make sure you have the appropriate GPU drivers installed on your system.
+
+### Installing `onnxruntime` (CPU version)
+
+If you don't have an NVIDIA GPU or prefer to use the CPU version, install the `onnxruntime` package using the following command:
+
+```bash
+pip install onnxruntime
+```
+
+## Usage
+
+After successfully installing the required packages, you can run the RTDETR implementation using the following command:
+
+```bash
+python main.py --model rtdetr-l.onnx --img image.jpg --conf-thres 0.5 --iou-thres 0.5
+```
+
+Replace `rtdetr-l.onnx` with the path to your RTDETR ONNX model file (see the export section below if you need to create one), replace `image.jpg` with the path to your input image, and adjust the confidence (`--conf-thres`) and IoU (`--iou-thres`) threshold values as needed.
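+
+## Exporting an ONNX Model (Optional)
+
+If you don't already have an RTDETR ONNX model, one way to create `rtdetr-l.onnx` is to export it from the Ultralytics `rtdetr-l.pt` weights (this assumes the `ultralytics` package is installed; the weights are downloaded automatically on first use):
+
+```bash
+yolo export model=rtdetr-l.pt format=onnx
+```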
diff --git a/examples/RTDETR-ONNXRuntime-Python/main.py b/examples/RTDETR-ONNXRuntime-Python/main.py
new file mode 100644
index 000000000..f25ad6082
--- /dev/null
+++ b/examples/RTDETR-ONNXRuntime-Python/main.py
@@ -0,0 +1,213 @@
+import argparse
+
+import cv2
+import numpy as np
+import onnxruntime as ort
+import torch
+
+from ultralytics.utils import ASSETS, yaml_load
+from ultralytics.utils.checks import check_requirements, check_yaml
+
+
+class RTDETR:
+    """RTDETR object detection model class for handling inference and visualization."""
+
+    def __init__(self, model_path, img_path, conf_thres=0.5, iou_thres=0.5):
+        """
+        Initializes the RTDETR object with the specified parameters.
+
+        Args:
+            model_path: Path to the ONNX model file.
+            img_path: Path to the input image.
+            conf_thres: Confidence threshold for object detection.
+            iou_thres: IoU threshold, accepted for CLI compatibility (RTDETR is end-to-end, so no NMS is applied here).
+        """
+        self.model_path = model_path
+        self.img_path = img_path
+        self.conf_thres = conf_thres
+        self.iou_thres = iou_thres
+
+        # Set up the ONNX Runtime session with CUDA and CPU execution providers
+        self.session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+        self.model_input = self.session.get_inputs()
+
+        # The model input is laid out as NCHW, so index 2 is the height and index 3 is the width
+        self.input_height = self.model_input[0].shape[2]
+        self.input_width = self.model_input[0].shape[3]
+
+        # Load class names from the COCO dataset YAML file
+        self.classes = yaml_load(check_yaml("coco8.yaml"))["names"]
+
+        # Generate a color palette for drawing bounding boxes
+        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
+
+    def draw_detections(self, box, score, class_id):
+        """
+        Draws bounding boxes and labels on the input image based on the detected objects.
+
+        Args:
+            box: Detected bounding box.
+            score: Corresponding detection score.
+            class_id: Class ID for the detected object.
+
+        Returns:
+            None
+        """
+        # Extract the coordinates of the bounding box
+        x1, y1, x2, y2 = box
+
+        # Retrieve the color for the class ID
+        color = self.color_palette[class_id]
+
+        # Draw the bounding box on the image
+        cv2.rectangle(self.img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
+
+        # Create the label text with class name and score
+        label = f"{self.classes[class_id]}: {score:.2f}"
+
+        # Calculate the dimensions of the label text
+        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+
+        # Calculate the position of the label text
+        label_x = x1
+        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
+
+        # Draw a filled rectangle as the background for the label text
+        cv2.rectangle(
+            self.img,
+            (int(label_x), int(label_y - label_height)),
+            (int(label_x + label_width), int(label_y + label_height)),
+            color,
+            cv2.FILLED,
+        )
+
+        # Draw the label text on the image
+        cv2.putText(self.img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
+
+    def preprocess(self):
+        """
+        Preprocesses the input image before performing inference.
+
+        Returns:
+            image_data: Preprocessed image data ready for inference.
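+
+        Example (illustrative sketch; assumes a local "rtdetr-l.onnx" export and "image.jpg" exist):
+            >>> detector = RTDETR("rtdetr-l.onnx", "image.jpg")
+            >>> blob = detector.preprocess()  # float32 array of shape (1, 3, input_height, input_width)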
+ """ + # Read the input image using OpenCV + self.img = cv2.imread(self.img_path) + + # Get the height and width of the input image + self.img_height, self.img_width = self.img.shape[:2] + + # Convert the image color space from BGR to RGB + img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB) + + # Resize the image to match the input shape + img = cv2.resize(img, (self.input_width, self.input_height)) + + # Normalize the image data by dividing it by 255.0 + image_data = np.array(img) / 255.0 + + # Transpose the image to have the channel dimension as the first dimension + image_data = np.transpose(image_data, (2, 0, 1)) # Channel first + + # Expand the dimensions of the image data to match the expected input shape + image_data = np.expand_dims(image_data, axis=0).astype(np.float32) + + # Return the preprocessed image data + return image_data + + def bbox_cxcywh_to_xyxy(self, boxes): + """ + Converts bounding boxes from (center x, center y, width, height) format + to (x_min, y_min, x_max, y_max) format. + + Args: + boxes (numpy.ndarray): An array of shape (N, 4) where each row represents + a bounding box in (cx, cy, w, h) format. + + Returns: + numpy.ndarray: An array of shape (N, 4) where each row represents + a bounding box in (x_min, y_min, x_max, y_max) format. + """ + # Calculate half width and half height of the bounding boxes + half_width = boxes[:, 2] / 2 + half_height = boxes[:, 3] / 2 + + # Calculate the coordinates of the bounding boxes + x_min = boxes[:, 0] - half_width + y_min = boxes[:, 1] - half_height + x_max = boxes[:, 0] + half_width + y_max = boxes[:, 1] + half_height + + # Return the bounding boxes in (x_min, y_min, x_max, y_max) format + return np.column_stack((x_min, y_min, x_max, y_max)) + + def postprocess(self, model_output): + """ + Postprocesses the model output to extract detections and draw them on the input image. + + Args: + model_output: Output of the model inference. + + Returns: + np.array: Annotated image with detections. + """ + # Squeeze the model output to remove unnecessary dimensions + outputs = np.squeeze(model_output[0]) + + # Extract bounding boxes and scores from the model output + boxes = outputs[:, :4] + scores = outputs[:, 4:] + + # Get the class labels and scores for each detection + labels = np.argmax(scores, axis=1) + scores = np.max(scores, axis=1) + + # Apply confidence threshold to filter out low-confidence detections + mask = scores > self.conf_thres + boxes, scores, labels = boxes[mask], scores[mask], labels[mask] + + # Convert bounding boxes to (x_min, y_min, x_max, y_max) format + boxes = self.bbox_cxcywh_to_xyxy(boxes) + + # Scale bounding boxes to match the original image dimensions + boxes[:, 0::2] *= self.img_width + boxes[:, 1::2] *= self.img_height + + # Draw detections on the image + for box, score, label in zip(boxes, scores, labels): + self.draw_detections(box, score, label) + + # Return the annotated image + return self.img + + def main(self): + """ + Executes the detection on the input image using the ONNX model. + + Returns: + np.array: Output image with annotations. 
+ """ + # Preprocess the image for model input + image_data = self.preprocess() + + # Run the model inference + model_output = self.session.run(None, {self.model_input[0].name: image_data}) + + # Process and return the model output + return self.postprocess(model_output) + +if __name__ == "__main__": + # Set up argument parser for command-line arguments + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, default="rtdetr-l.onnx", help="Path to the ONNX model file.") + parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to the input image.") + parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold for object detection.") + parser.add_argument("--iou-thres", type=float, default=0.5, help="IoU threshold for non-maximum suppression.") + args = parser.parse_args() + + # Check for dependencies and set up ONNX runtime + check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime") + + # Create the detector instance with specified parameters + detection = RTDETR(args.model, args.img, args.conf_thres, args.iou_thres) + + # Perform detection and get the output image + output_image = detection.main() + + # Display the annotated output image + cv2.namedWindow("Output", cv2.WINDOW_NORMAL) + cv2.imshow("Output", output_image) + cv2.waitKey(0) \ No newline at end of file