diff --git a/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml b/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml
new file mode 100644
index 0000000000..101f72efb7
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "yolov8-rs"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+clap = { version = "4.2.4", features = ["derive"] }
+image = { version = "0.24.7", default-features = false, features = ["jpeg", "png", "webp-encoder"] }
+imageproc = { version = "0.23.0", default-features = false }
+ndarray = { version = "0.15.6" }
+ort = { version = "1.16.3", default-features = false, features = ["load-dynamic", "copy-dylibs", "half"] }
+rusttype = { version = "0.9", default-features = false }
+anyhow = { version = "1.0.75" }
+regex = { version = "1.5.4" }
+rand = { version = "0.8.5" }
+chrono = { version = "0.4.30" }
+half = { version = "2.3.1" }
+dirs = { version = "5.0.1" }
+ureq = { version = "2.9.1" }
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/README.md b/examples/YOLOv8-ONNXRuntime-Rust/README.md
new file mode 100644
index 0000000000..6876c15e91
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/README.md
@@ -0,0 +1,222 @@
+# YOLOv8-ONNXRuntime-Rust for All the Key YOLO Tasks
+
+This repository provides a Rust demo for performing YOLOv8 tasks like `Classification`, `Segmentation`, `Detection` and `Pose Detection` using ONNXRuntime.
+
+## Features
+
+- Supports the `Classification`, `Segmentation`, `Detection` and `Pose (Keypoints) Detection` tasks.
+- Supports `FP16` & `FP32` ONNX models.
+- Supports the `CPU`, `CUDA` and `TensorRT` execution providers to accelerate computation.
+- Supports dynamic input shapes (`batch`, `width`, `height`).
+
+## Installation
+
+### 1. Install Rust
+
+Please follow the official Rust installation guide. (https://www.rust-lang.org/tools/install)
+
+### 2. Install ONNXRuntime
+
+This repository uses the `ort` crate, an ONNXRuntime wrapper for Rust. (https://docs.rs/ort/latest/ort/)
+
+You can follow the instructions in the `ort` docs or simply do this:
+
+- Step 1: Download ONNXRuntime. (https://github.com/microsoft/onnxruntime/releases)
+- Step 2: Set the `LD_LIBRARY_PATH` environment variable for linking.
+
+On Ubuntu, you can do it like this:
+
+```
+vim ~/.bashrc
+
+# Add the path of the ONNXRuntime lib
+export LD_LIBRARY_PATH=/home/qweasd/Documents/onnxruntime-linux-x64-gpu-1.16.3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
+
+source ~/.bashrc
+```
+
+### 3. \[Optional\] Install CUDA & cuDNN & TensorRT
+
+- The CUDA execution provider requires CUDA v11.6+.
+- The TensorRT execution provider requires CUDA v11.4+ and TensorRT v8.4+.
+
+## Get Started
+
+### 1. Export the YOLOv8 ONNX Models
+
+```bash
+pip install -U ultralytics
+
+# export onnx model with dynamic shapes
+yolo export model=yolov8m.pt format=onnx simplify dynamic
+yolo export model=yolov8m-cls.pt format=onnx simplify dynamic
+yolo export model=yolov8m-pose.pt format=onnx simplify dynamic
+yolo export model=yolov8m-seg.pt format=onnx simplify dynamic
+
+# export onnx model with constant shapes
+yolo export model=yolov8m.pt format=onnx simplify
+yolo export model=yolov8m-cls.pt format=onnx simplify
+yolo export model=yolov8m-pose.pt format=onnx simplify
+yolo export model=yolov8m-seg.pt format=onnx simplify
+```
+
+### 2. Run Inference
+
+This will perform inference with the ONNX model on the source image.
+
+```
+cargo run --release -- --model <MODEL> --source <SOURCE>
+```
+
+Set `--cuda` to use the CUDA execution provider to speed up inference.
+
+```
+cargo run --release -- --cuda --model <MODEL> --source <SOURCE>
+```
+
+Set `--trt` to use the TensorRT execution provider, and you can set `--fp16` at the same time to use a TensorRT FP16 engine.
+
+```
+cargo run --release -- --trt --fp16 --model <MODEL> --source <SOURCE>
+```
+
+Set `--device_id` to select which device to run on. If you have only one GPU and set `device_id` to 1, the program will not panic; `ort` automatically falls back to the `CPU` EP.
+
+```
+cargo run --release -- --cuda --device_id 0 --model <MODEL> --source <SOURCE>
+```
+
+Set `--batch` to do multi-batch-size inference.
+
+If you're using `--trt`, you can also set `--batch-min` and `--batch-max` to explicitly specify the min/max/opt batch sizes for dynamic batch input (https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#explicit-shape-range-for-dynamic-shape-input). (Note that the ONNX model should be exported with dynamic shapes.)
+
+```
+cargo run --release -- --cuda --batch 2 --model <MODEL> --source <SOURCE>
+```
+
+Set `--height` and `--width` to do dynamic image size inference. (Note that the ONNX model should be exported with dynamic shapes.)
+
+```
+cargo run --release -- --cuda --width 480 --height 640 --model <MODEL> --source <SOURCE>
+```
+
+Set `--profile` to check the time consumed in each stage. (Note that the model usually needs 1~3 dry runs to warm up. Make sure to run enough times to evaluate the result.)
+
+```
+cargo run --release -- --trt --fp16 --profile --model <MODEL> --source <SOURCE>
+```
+
+Results: (yolov8m.onnx, batch=1, 3 times, trt, fp16, RTX 3060Ti)
+
+```
+==> 0
+[Model Preprocess]: 12.75788ms
+[ORT H2D]: 237.118µs
+[ORT Inference]: 507.895469ms
+[ORT D2H]: 191.655µs
+[Model Inference]: 508.34589ms
+[Model Postprocess]: 1.061122ms
+==> 1
+[Model Preprocess]: 13.658655ms
+[ORT H2D]: 209.975µs
+[ORT Inference]: 5.12372ms
+[ORT D2H]: 182.389µs
+[Model Inference]: 5.530022ms
+[Model Postprocess]: 1.04851ms
+==> 2
+[Model Preprocess]: 12.475332ms
+[ORT H2D]: 246.127µs
+[ORT Inference]: 5.048432ms
+[ORT D2H]: 187.117µs
+[Model Inference]: 5.493119ms
+[Model Postprocess]: 1.040906ms
+```
+
+And also:
+
+`--conf`: confidence threshold \[default: 0.3\]
+
+`--iou`: IoU threshold in NMS \[default: 0.45\]
+
+`--kconf`: confidence threshold of keypoint \[default: 0.55\]
+
+`--plot`: plot inference results with random RGB colors and save
+
+You can check out all CLI arguments with:
+
+```
+git clone https://github.com/ultralytics/ultralytics
+cd ultralytics/examples/YOLOv8-ONNXRuntime-Rust
+cargo run --release -- --help
+```
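+
+### 3. Use as a Library
+
+You can also call the model from your own code instead of the CLI. The sketch below simply mirrors what `src/main.rs` does; it still builds the config through clap's `Args::parse()`, so run it with the same flags as above.
+
+```rust
+use clap::Parser;
+use yolov8_rs::{Args, YOLOv8};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Same CLI-derived config the binary uses (see src/cli.rs)
+    let args = Args::parse();
+
+    // Load the source image
+    let img = image::io::Reader::open(&args.source)?
+        .with_guessed_format()?
+        .decode()?;
+
+    // Build the model and run inference on a batch of one image
+    let mut model = YOLOv8::new(args)?;
+    let ys = model.run(&vec![img])?;
+    println!("{:?}", ys);
+    Ok(())
+}
+```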
+
+## Examples
+
+### Classification
+
+Running a dynamic-shape ONNX model on `CPU` with image size `--height 224 --width 224`, and saving the plotted image in the `runs` directory.
+
+```
+cargo run --release -- --model ../assets/weights/yolov8m-cls-dyn.onnx --source ../assets/images/dog.jpg --height 224 --width 224 --plot --profile
+```
+
+You will see output like:
+
+```
+Summary:
+> Task: Classify (Ultralytics 8.0.217)
+> EP: Cpu
+> Dtype: Float32
+> Batch: 1 (Dynamic), Height: 224 (Dynamic), Width: 224 (Dynamic)
+> nc: 1000 nk: 0, nm: 0, conf: 0.3, kconf: 0.55, iou: 0.45
+
+[Model Preprocess]: 16.363477ms
+[ORT H2D]: 50.722µs
+[ORT Inference]: 16.295808ms
+[ORT D2H]: 8.37µs
+[Model Inference]: 16.367046ms
+[Model Postprocess]: 3.527µs
+[
+    YOLOResult {
+        Probs(top5): Some([(208, 0.6950566), (209, 0.13823675), (178, 0.04849795), (215, 0.019029364), (212, 0.016506357)]),
+        Bboxes: None,
+        Keypoints: None,
+        Masks: None,
+    },
+]
+```
+
+![2023-11-25-22-02-02-156623351](https://github.com/jamjamjon/ultralytics/assets/51357717/ef75c2ae-c5ab-44cc-9d9e-e60b51e39662)
+
+### Object Detection
+
+Using the `CUDA` EP and dynamic image size `--height 640 --width 480`:
+
+```
+cargo run --release -- --cuda --model ../assets/weights/yolov8m-dynamic.onnx --source ../assets/images/bus.jpg --plot --height 640 --width 480
+```
+
+![det](https://github.com/jamjamjon/ultralytics/assets/51357717/5d89a19d-0c96-4a59-875c-defab6887a2c)
+
+### Pose Detection
+
+Using the `TensorRT` EP:
+
+```
+cargo run --release -- --trt --model ../assets/weights/yolov8m-pose.onnx --source ../assets/images/bus.jpg --plot
+```
+
+![2023-11-25-22-31-45-127054025](https://github.com/jamjamjon/ultralytics/assets/51357717/157b5ba7-bfcf-47cf-bee7-68b62e0de1c4)
+
+### Instance Segmentation
+
+Using the `TensorRT` EP and an FP16 model (`--fp16`):
+
+```
+cargo run --release -- --trt --fp16 --model ../assets/weights/yolov8m-seg.onnx --source ../assets/images/0172.jpg --plot
+```
+
+![seg](https://github.com/jamjamjon/ultralytics/assets/51357717/cf046f4f-9533-478a-adc7-4de22443a641)
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs
new file mode 100644
index 0000000000..2ba0dd49ec
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs
@@ -0,0 +1,87 @@
+use clap::Parser;
+
+use crate::YOLOTask;
+
+#[derive(Parser, Clone)]
+#[command(author, version, about, long_about = None)]
+pub struct Args {
+    /// ONNX model path
+    #[arg(long, required = true)]
+    pub model: String,
+
+    /// input path
+    #[arg(long, required = true)]
+    pub source: String,
+
+    /// device id
+    #[arg(long, default_value_t = 0)]
+    pub device_id: u32,
+
+    /// using TensorRT EP
+    #[arg(long)]
+    pub trt: bool,
+
+    /// using CUDA EP
+    #[arg(long)]
+    pub cuda: bool,
+
+    /// input batch size
+    #[arg(long, default_value_t = 1)]
+    pub batch: u32,
+
+    /// trt input min_batch size
+    #[arg(long, default_value_t = 1)]
+    pub batch_min: u32,
+
+    /// trt input max_batch size
+    #[arg(long, default_value_t = 32)]
+    pub batch_max: u32,
+
+    /// using TensorRT --fp16
+    #[arg(long)]
+    pub fp16: bool,
+
+    /// specify YOLO task
+    #[arg(long, value_enum)]
+    pub task: Option<YOLOTask>,
+
+    /// num_classes
+    #[arg(long)]
+    pub nc: Option<u32>,
+
+    /// num_keypoints
+    #[arg(long)]
+    pub nk: Option<u32>,
+
+    /// num_masks
+    #[arg(long)]
+    pub nm: Option<u32>,
+
+    /// input image width
+    #[arg(long)]
+    pub width: Option<u32>,
+
+    /// input image height
+    #[arg(long)]
+    pub height: Option<u32>,
+
+    /// confidence threshold
+    #[arg(long, required = false, default_value_t = 0.3)]
+    pub conf: f32,
+
+    /// iou threshold in NMS
+    #[arg(long, required = false, default_value_t = 0.45)]
+    pub iou: f32,
+
+    /// confidence threshold of keypoint
+    #[arg(long, required = false, default_value_t = 0.55)]
+    pub kconf: f32,
+
+    /// plot inference result and save
+    #[arg(long)]
+    pub plot: bool,
+
+    /// check time consumed in each stage
+    #[arg(long)]
+    pub profile: bool,
+}
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs
new file mode 100644
index 0000000000..1af7f7c5e1
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs
@@ -0,0 +1,119 @@
+#![allow(clippy::type_complexity)]
+
+use std::io::{Read, Write};
+
+pub mod cli;
+pub mod model;
+pub mod ort_backend;
+pub mod yolo_result;
+pub use crate::cli::Args;
+pub use crate::model::YOLOv8;
+pub use crate::ort_backend::{Batch, OrtBackend, OrtConfig, OrtEP, YOLOTask};
+pub use crate::yolo_result::{Bbox, Embedding, Point2, YOLOResult};
+
+pub fn non_max_suppression(
+    xs: &mut Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)>,
+    iou_threshold: f32,
+) {
+    // sort by confidence, descending
+    xs.sort_by(|b1, b2| b2.0.confidence().partial_cmp(&b1.0.confidence()).unwrap());
+
+    let mut current_index = 0;
+    for index in 0..xs.len() {
+        let mut drop = false;
+        for prev_index in 0..current_index {
+            let iou = xs[prev_index].0.iou(&xs[index].0);
+            if iou > iou_threshold {
+                drop = true;
+                break;
+            }
+        }
+        if !drop {
+            xs.swap(current_index, index);
+            current_index += 1;
+        }
+    }
+    xs.truncate(current_index);
+}
+
+pub fn gen_time_string(delimiter: &str) -> String {
+    let offset = chrono::FixedOffset::east_opt(8 * 60 * 60).unwrap(); // Beijing
+    let t_now = chrono::Utc::now().with_timezone(&offset);
+    let fmt = format!(
+        "%Y{}%m{}%d{}%H{}%M{}%S{}%f",
+        delimiter, delimiter, delimiter, delimiter, delimiter, delimiter
+    );
+    t_now.format(&fmt).to_string()
+}
+
+pub const SKELETON: [(usize, usize); 16] = [
+    (0, 1),
+    (0, 2),
+    (1, 3),
+    (2, 4),
+    (5, 6),
+    (5, 11),
+    (6, 12),
+    (11, 12),
+    (5, 7),
+    (6, 8),
+    (7, 9),
+    (8, 10),
+    (11, 13),
+    (12, 14),
+    (13, 15),
+    (14, 16),
+];
+
+pub fn check_font(font: &str) -> rusttype::Font<'static> {
+    // check then load font
+
+    // ultralytics font path
+    let font_path_config = match dirs::config_dir() {
+        Some(mut d) => {
+            d.push("Ultralytics");
+            d.push(font);
+            d
+        }
+        None => panic!("Unsupported operating system. Supported: Linux, macOS, Windows."),
+    };
+
+    // current font path
+    let font_path_current = std::path::PathBuf::from(font);
+
+    // check font
+    let font_path = if font_path_config.exists() {
+        font_path_config
+    } else if font_path_current.exists() {
+        font_path_current
+    } else {
+        println!("Downloading font...");
+        let source_url = "https://ultralytics.com/assets/Arial.ttf";
+        let resp = ureq::get(source_url)
+            .timeout(std::time::Duration::from_secs(500))
+            .call()
+            .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));
+
+        // read to buffer
+        let mut buffer = vec![];
+        let total_size = resp
+            .header("Content-Length")
+            .and_then(|s| s.parse::<u64>().ok())
+            .unwrap();
+        let _reader = resp
+            .into_reader()
+            .take(total_size)
+            .read_to_end(&mut buffer)
+            .unwrap();
+
+        // save
+        let _path = std::fs::File::create(font).unwrap();
+        let mut writer = std::io::BufWriter::new(_path);
+        writer.write_all(&buffer).unwrap();
+        println!("Font saved at: {:?}", font_path_current.display());
+        font_path_current
+    };
+
+    // load font
+    let buffer = std::fs::read(font_path).unwrap();
+    rusttype::Font::try_from_vec(buffer).unwrap()
+}
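+
+// A minimal sanity check for `non_max_suppression`, added for illustration (not
+// part of the original example). Note that `Bbox::intersection_area` uses a
+// +1 pixel convention, so the IoU of heavily overlapping boxes can reach 1.0
+// or slightly above; the threshold below comfortably suppresses the overlap.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn nms_keeps_highest_confidence_box() {
+        let a = Bbox::new(0., 0., 10., 10., 0, 0.9);
+        let b = Bbox::new(1., 1., 10., 10., 0, 0.8); // heavily overlaps `a`
+        let c = Bbox::new(100., 100., 10., 10., 0, 0.7); // disjoint from `a`
+        let mut xs = vec![(b, None, None), (a, None, None), (c, None, None)];
+        non_max_suppression(&mut xs, 0.45);
+        // `b` is suppressed by `a`; the disjoint `c` survives
+        assert_eq!(xs.len(), 2);
+        assert_eq!(xs[0].0.confidence(), 0.9);
+    }
+}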
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs
new file mode 100644
index 0000000000..8dd1567990
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs
@@ -0,0 +1,28 @@
+use clap::Parser;
+
+use yolov8_rs::{Args, YOLOv8};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let args = Args::parse();
+
+    // 1. load image
+    let x = image::io::Reader::open(&args.source)?
+        .with_guessed_format()?
+        .decode()?;
+
+    // 2. the model supports dynamic batch inference, so the input is a Vec of images
+    let xs = vec![x];
+
+    // You can test `--batch 2` with this
+    // let xs = vec![x.clone(), x];
+
+    // 3. build yolov8 model
+    let mut model = YOLOv8::new(args)?;
+    model.summary(); // model info
+
+    // 4. run
+    let ys = model.run(&xs)?;
+    println!("{:?}", ys);
+
+    Ok(())
+}
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs
new file mode 100644
index 0000000000..1c0e5e494d
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs
@@ -0,0 +1,642 @@
+#![allow(clippy::type_complexity)]
+
+use anyhow::Result;
+use image::{DynamicImage, GenericImageView, ImageBuffer};
+use ndarray::{s, Array, Axis, IxDyn};
+use rand::{thread_rng, Rng};
+use std::path::PathBuf;
+
+use crate::{
+    check_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
+    OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON,
+};
+
+pub struct YOLOv8 {
+    // YOLOv8 model for all yolo-tasks
+    engine: OrtBackend,
+    nc: u32,
+    nk: u32,
+    nm: u32,
+    height: u32,
+    width: u32,
+    batch: u32,
+    task: YOLOTask,
+    conf: f32,
+    kconf: f32,
+    iou: f32,
+    names: Vec<String>,
+    color_palette: Vec<(u8, u8, u8)>,
+    profile: bool,
+    plot: bool,
+}
+
+impl YOLOv8 {
+    pub fn new(config: Args) -> Result<Self> {
+        // execution provider
+        let ep = if config.trt {
+            OrtEP::Trt(config.device_id)
+        } else if config.cuda {
+            OrtEP::Cuda(config.device_id)
+        } else {
+            OrtEP::Cpu
+        };
+
+        // batch
+        let batch = Batch {
+            opt: config.batch,
+            min: config.batch_min,
+            max: config.batch_max,
+        };
+
+        // build ort engine
+        let ort_args = OrtConfig {
+            ep,
+            batch,
+            f: config.model,
+            task: config.task,
+            trt_fp16: config.fp16,
+            image_size: (config.height, config.width),
+        };
+        let engine = OrtBackend::build(ort_args)?;
+
+        // get batch, height, width, tasks, nc, nk, nm
+        let (batch, height, width, task) = (
+            engine.batch(),
+            engine.height(),
+            engine.width(),
+            engine.task(),
+        );
+        let nc = engine.nc().or(config.nc).unwrap_or_else(|| {
+            panic!("Failed to get num_classes, make it explicit with `--nc`");
+        });
+        let (nk, nm) = match task {
+            YOLOTask::Pose => {
+                let nk = engine.nk().or(config.nk).unwrap_or_else(|| {
+                    panic!("Failed to get num_keypoints, make it explicit with `--nk`");
+                });
+                (nk, 0)
+            }
+            YOLOTask::Segment => {
+                let nm = engine.nm().or(config.nm).unwrap_or_else(|| {
+                    panic!("Failed to get num_masks, make it explicit with `--nm`");
+                });
+                (0, nm)
+            }
+            _ => (0, 0),
+        };
+
+        // class names
+        let names = engine.names().unwrap_or(vec!["Unknown".to_string()]);
+
+        // color palette
+        let mut rng = thread_rng();
+        let color_palette: Vec<_> = names
+            .iter()
+            .map(|_| {
+                (
+                    rng.gen_range(0..=255),
+                    rng.gen_range(0..=255),
+                    rng.gen_range(0..=255),
+                )
+            })
+            .collect();
+
+        Ok(Self {
+            engine,
+            names,
+            conf: config.conf,
+            kconf: config.kconf,
+            iou: config.iou,
+            color_palette,
+            profile: config.profile,
+            plot: config.plot,
+            nc,
+            nk,
+            nm,
+            height,
+            width,
+            batch,
+            task,
+        })
+    }
+
+    pub fn scale_wh(&self, w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
+        // compute the letterbox scale ratio and the resized (width, height)
+        let r = (w1 / w0).min(h1 / h0);
+        (r, (w0 * r).round(), (h0 * r).round())
+    }
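+
+    // Illustrative example (added for clarity, not in the original code):
+    // scale_wh(1280., 720., 640., 640.) returns (0.5, 640., 360.), since
+    // r = min(640/1280, 640/720) = 0.5; a 1280x720 image is thus resized to
+    // 640x360, and the rest of the 640x640 model input keeps the gray fill
+    // value (144/255) set in `preprocess` below.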
+
+    pub fn preprocess(&mut self, xs: &Vec<DynamicImage>) -> Result<Array<f32, IxDyn>> {
+        let mut ys =
+            Array::ones((xs.len(), 3, self.height() as usize, self.width() as usize)).into_dyn();
+        ys.fill(144.0 / 255.0);
+        for (idx, x) in xs.iter().enumerate() {
+            let img = match self.task() {
+                YOLOTask::Classify => x.resize_exact(
+                    self.width(),
+                    self.height(),
+                    image::imageops::FilterType::Triangle,
+                ),
+                _ => {
+                    let (w0, h0) = x.dimensions();
+                    let w0 = w0 as f32;
+                    let h0 = h0 as f32;
+                    let (_, w_new, h_new) =
+                        self.scale_wh(w0, h0, self.width() as f32, self.height() as f32); // f32 round
+                    x.resize_exact(
+                        w_new as u32,
+                        h_new as u32,
+                        if let YOLOTask::Segment = self.task() {
+                            image::imageops::FilterType::CatmullRom
+                        } else {
+                            image::imageops::FilterType::Triangle
+                        },
+                    )
+                }
+            };
+
+            for (x, y, rgb) in img.pixels() {
+                let x = x as usize;
+                let y = y as usize;
+                let [r, g, b, _] = rgb.0;
+                ys[[idx, 0, y, x]] = (r as f32) / 255.0;
+                ys[[idx, 1, y, x]] = (g as f32) / 255.0;
+                ys[[idx, 2, y, x]] = (b as f32) / 255.0;
+            }
+        }
+
+        Ok(ys)
+    }
+
+    pub fn run(&mut self, xs: &Vec<DynamicImage>) -> Result<Vec<YOLOResult>> {
+        // pre-process
+        let t_pre = std::time::Instant::now();
+        let xs_ = self.preprocess(xs)?;
+        if self.profile {
+            println!("[Model Preprocess]: {:?}", t_pre.elapsed());
+        }
+
+        // run
+        let t_run = std::time::Instant::now();
+        let ys = self.engine.run(xs_, self.profile)?;
+        if self.profile {
+            println!("[Model Inference]: {:?}", t_run.elapsed());
+        }
+
+        // post-process
+        let t_post = std::time::Instant::now();
+        let ys = self.postprocess(ys, xs)?;
+        if self.profile {
+            println!("[Model Postprocess]: {:?}", t_post.elapsed());
+        }
+
+        // plot and save
+        if self.plot {
+            self.plot_and_save(&ys, xs, Some(&SKELETON));
+        }
+        Ok(ys)
+    }
+
+    pub fn postprocess(
+        &self,
+        xs: Vec<Array<f32, IxDyn>>,
+        xs0: &[DynamicImage],
+    ) -> Result<Vec<YOLOResult>> {
+        if let YOLOTask::Classify = self.task() {
+            let mut ys = Vec::new();
+            let preds = &xs[0];
+            for batch in preds.axis_iter(Axis(0)) {
+                ys.push(YOLOResult::new(
+                    Some(Embedding::new(batch.into_owned())),
+                    None,
+                    None,
+                    None,
+                ));
+            }
+            Ok(ys)
+        } else {
+            const CXYWH_OFFSET: usize = 4; // cxcywh
+            const KPT_STEP: usize = 3; // xyconf
+            let preds = &xs[0];
+            let protos = {
+                if xs.len() > 1 {
+                    Some(&xs[1])
+                } else {
+                    None
+                }
+            };
+            let mut ys = Vec::new();
+            for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() {
+                // [bs, 4 + nc + nm, anchors]
+                // input image
+                let width_original = xs0[idx].width() as f32;
+                let height_original = xs0[idx].height() as f32;
+                let ratio = (self.width() as f32 / width_original)
+                    .min(self.height() as f32 / height_original);
+
+                // save each result
+                let mut data: Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)> = Vec::new();
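+                // Each `pred` column below is laid out as follows (illustrative
+                // summary, added for clarity; the tail depends on the task):
+                //   Detect:  [cx, cy, w, h, class_0 .. class_{nc-1}]
+                //   Pose:    [cx, cy, w, h, class scores, (x, y, conf) * nk]
+                //   Segment: [cx, cy, w, h, class scores, nm mask coefficients]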
+                for pred in anchor.axis_iter(Axis(1)) {
+                    // split preds for different tasks
+                    let bbox = pred.slice(s![0..CXYWH_OFFSET]);
+                    let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc() as usize]);
+                    let kpts = {
+                        if let YOLOTask::Pose = self.task() {
+                            Some(pred.slice(s![pred.len() - KPT_STEP * self.nk() as usize..]))
+                        } else {
+                            None
+                        }
+                    };
+                    let coefs = {
+                        if let YOLOTask::Segment = self.task() {
+                            Some(pred.slice(s![pred.len() - self.nm() as usize..]).to_vec())
+                        } else {
+                            None
+                        }
+                    };
+
+                    // confidence and id
+                    let (id, &confidence) = clss
+                        .into_iter()
+                        .enumerate()
+                        .reduce(|max, x| if x.1 > max.1 { x } else { max })
+                        .unwrap(); // definitely will not panic!
+
+                    // confidence filter
+                    if confidence < self.conf {
+                        continue;
+                    }
+
+                    // bbox re-scale
+                    let cx = bbox[0] / ratio;
+                    let cy = bbox[1] / ratio;
+                    let w = bbox[2] / ratio;
+                    let h = bbox[3] / ratio;
+                    let x = cx - w / 2.;
+                    let y = cy - h / 2.;
+                    let y_bbox = Bbox::new(
+                        x.max(0.0f32).min(width_original),
+                        y.max(0.0f32).min(height_original),
+                        w,
+                        h,
+                        id,
+                        confidence,
+                    );
+
+                    // kpts
+                    let y_kpts = {
+                        if let Some(kpts) = kpts {
+                            let mut kpts_ = Vec::new();
+                            // rescale
+                            for i in 0..self.nk() as usize {
+                                let kx = kpts[KPT_STEP * i] / ratio;
+                                let ky = kpts[KPT_STEP * i + 1] / ratio;
+                                let kconf = kpts[KPT_STEP * i + 2];
+                                if kconf < self.kconf {
+                                    kpts_.push(Point2::default());
+                                } else {
+                                    kpts_.push(Point2::new_with_conf(
+                                        kx.max(0.0f32).min(width_original),
+                                        ky.max(0.0f32).min(height_original),
+                                        kconf,
+                                    ));
+                                }
+                            }
+                            Some(kpts_)
+                        } else {
+                            None
+                        }
+                    };
+
+                    // data merged
+                    data.push((y_bbox, y_kpts, coefs));
+                }
+
+                // nms
+                non_max_suppression(&mut data, self.iou);
+
+                // decode
+                let mut y_bboxes: Vec<Bbox> = Vec::new();
+                let mut y_kpts: Vec<Vec<Point2>> = Vec::new();
+                let mut y_masks: Vec<Vec<u8>> = Vec::new();
+                for elem in data.into_iter() {
+                    if let Some(kpts) = elem.1 {
+                        y_kpts.push(kpts)
+                    }
+
+                    // decode masks
+                    if let Some(coefs) = elem.2 {
+                        let proto = protos.unwrap().slice(s![idx, .., .., ..]);
+                        let (nm, nh, nw) = proto.dim();
+
+                        // coefs * proto -> mask
+                        let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm)
+                        let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw)
+                        let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n)
+
+                        // build image from ndarray
+                        let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
+                            match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
+                                Some(image) => image,
+                                None => panic!("can not create image from ndarray"),
+                            };
+                        let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn
+
+                        // rescale masks
+                        let (_, w_mask, h_mask) =
+                            self.scale_wh(width_original, height_original, nw as f32, nh as f32);
+                        let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
+                        let mask_original = mask_cropped.resize_exact(
+                            // resize_to_fill
+                            width_original as u32,
+                            height_original as u32,
+                            match self.task() {
+                                YOLOTask::Segment => image::imageops::FilterType::CatmullRom,
+                                _ => image::imageops::FilterType::Triangle,
+                            },
+                        );
+
+                        // crop-mask with bbox
+                        let mut mask_original_cropped = mask_original.into_luma8();
+                        for y in 0..height_original as usize {
+                            for x in 0..width_original as usize {
+                                if x < elem.0.xmin() as usize
+                                    || x > elem.0.xmax() as usize
+                                    || y < elem.0.ymin() as usize
+                                    || y > elem.0.ymax() as usize
+                                {
+                                    mask_original_cropped.put_pixel(
+                                        x as u32,
+                                        y as u32,
+                                        image::Luma([0u8]),
+                                    );
+                                }
+                            }
+                        }
+                        y_masks.push(mask_original_cropped.into_raw());
+                    }
+                    y_bboxes.push(elem.0);
+                }
+
+                // save each result
+                let y = YOLOResult {
+                    probs: None,
+                    bboxes: if !y_bboxes.is_empty() {
+                        Some(y_bboxes)
+                    } else {
+                        None
+                    },
+                    keypoints: if !y_kpts.is_empty() {
+                        Some(y_kpts)
+                    } else {
+                        None
+                    },
+                    masks: if !y_masks.is_empty() {
+                        Some(y_masks)
+                    } else {
+                        None
+                    },
+                };
+                ys.push(y);
+            }
+
+            Ok(ys)
+        }
+    }
+
+    pub fn plot_and_save(
+        &self,
+        ys: &[YOLOResult],
+        xs0: &[DynamicImage],
+        skeletons: Option<&[(usize, usize)]>,
+    ) {
+        // check font then load
+        let font = check_font("Arial.ttf");
+        for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() {
+            let mut img = img0.to_rgb8();
+
+            // draw for classifier
+            if let Some(probs) = y.probs() {
+                for (i, k) in probs.topk(5).iter().enumerate() {
+                    let legend = format!("{} {:.2}%", self.names[k.0], k.1);
+                    let scale = 32;
+                    let legend_size = img.width().max(img.height()) / scale;
+                    let x = img.width() / 20;
+                    let y = img.height() / 20 + i as u32 * legend_size;
+                    imageproc::drawing::draw_text_mut(
+                        &mut img,
+                        image::Rgb([0, 255, 0]),
+                        x as i32,
+                        y as i32,
+                        rusttype::Scale::uniform(legend_size as f32 - 1.),
+                        &font,
+                        &legend,
+                    );
+                }
+            }
+
+            // draw bboxes & keypoints
+            if let Some(bboxes) = y.bboxes() {
+                for (_idx, bbox) in bboxes.iter().enumerate() {
+                    // rect
+                    imageproc::drawing::draw_hollow_rect_mut(
+                        &mut img,
+                        imageproc::rect::Rect::at(bbox.xmin() as i32, bbox.ymin() as i32)
+                            .of_size(bbox.width() as u32, bbox.height() as u32),
+                        image::Rgb(self.color_palette[bbox.id()].into()),
+                    );
+
+                    // text
+                    let legend = format!("{} {:.2}%", self.names[bbox.id()], bbox.confidence());
+                    let scale = 40;
+                    let legend_size = img.width().max(img.height()) / scale;
+                    imageproc::drawing::draw_text_mut(
+                        &mut img,
+                        image::Rgb(self.color_palette[bbox.id()].into()),
+                        bbox.xmin() as i32,
+                        (bbox.ymin() - legend_size as f32) as i32,
+                        rusttype::Scale::uniform(legend_size as f32 - 1.),
+                        &font,
+                        &legend,
+                    );
+                }
+            }
+
+            // draw kpts
+            if let Some(keypoints) = y.keypoints() {
+                for kpts in keypoints.iter() {
+                    for kpt in kpts.iter() {
+                        // filter
+                        if kpt.confidence() < self.kconf {
+                            continue;
+                        }
+
+                        // draw point
+                        imageproc::drawing::draw_filled_circle_mut(
+                            &mut img,
+                            (kpt.x() as i32, kpt.y() as i32),
+                            2,
+                            image::Rgb([0, 255, 0]),
+                        );
+                    }
+
+                    // draw skeleton if has
+                    if let Some(skeletons) = skeletons {
+                        for &(idx1, idx2) in skeletons.iter() {
+                            let kpt1 = &kpts[idx1];
+                            let kpt2 = &kpts[idx2];
+                            if kpt1.confidence() < self.kconf || kpt2.confidence() < self.kconf {
+                                continue;
+                            }
+                            imageproc::drawing::draw_line_segment_mut(
+                                &mut img,
+                                (kpt1.x(), kpt1.y()),
+                                (kpt2.x(), kpt2.y()),
+                                image::Rgb([233, 14, 57]),
+                            );
+                        }
+                    }
+                }
+            }
+
+            // draw mask
+            if let Some(masks) = y.masks() {
+                for (mask, _bbox) in masks.iter().zip(y.bboxes().unwrap().iter()) {
+                    let mask_nd: ImageBuffer<image::Luma<u8>, Vec<u8>> =
+                        match ImageBuffer::from_vec(img.width(), img.height(), mask.to_vec()) {
+                            Some(image) => image,
+                            None => panic!("can not create image from ndarray"),
+                        };
+
+                    for _x in 0..img.width() {
+                        for _y in 0..img.height() {
+                            let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_nd, _x, _y);
+                            if mask_p.0[0] > 0 {
+                                let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, _x, _y);
+                                // img_p.0[2] = self.color_palette[bbox.id()].2 / 2;
+                                // img_p.0[1] = self.color_palette[bbox.id()].1 / 2;
+                                // img_p.0[0] = self.color_palette[bbox.id()].0 / 2;
+                                img_p.0[2] /= 2;
+                                img_p.0[1] = 255 - (255 - img_p.0[2]) / 2;
+                                img_p.0[0] /= 2;
+                                imageproc::drawing::Canvas::draw_pixel(&mut img, _x, _y, img_p)
+                            }
+                        }
+                    }
+                }
+            }
+
+            // mkdir and save
+            let mut runs = PathBuf::from("runs");
+            if !runs.exists() {
+                std::fs::create_dir_all(&runs).unwrap();
+            }
+            runs.push(gen_time_string("-"));
+            let saveout = format!("{}.jpg", runs.to_str().unwrap());
+            let _ = img.save(saveout);
+        }
+    }
+
+    pub fn summary(&self) {
+        println!(
+            "\nSummary:\n\
+            > Task: {:?}{}\n\
+            > EP: {:?} {}\n\
+            > Dtype: {:?}\n\
+            > Batch: {} ({}), Height: {} ({}), Width: {} ({})\n\
+            > nc: {} nk: {}, nm: {}, conf: {}, kconf: {}, iou: {}\n\
+            ",
+            self.task(),
+            match self.engine.author().zip(self.engine.version()) {
+                Some((author, ver)) => format!(" ({} {})", author, ver),
+                None => String::from(""),
+            },
+            self.engine.ep(),
+            if let OrtEP::Cpu = self.engine.ep() {
+                ""
+            } else {
+                "(May still fall back to CPU)"
+            },
+            self.engine.dtype(),
+            self.batch(),
+            if self.engine.is_batch_dynamic() {
+                "Dynamic"
+            } else {
+                "Const"
+            },
+            self.height(),
+            if self.engine.is_height_dynamic() {
+                "Dynamic"
+            } else {
+                "Const"
+            },
+            self.width(),
+            if self.engine.is_width_dynamic() {
+                "Dynamic"
+            } else {
+                "Const"
+            },
+            self.nc(),
+            self.nk(),
+            self.nm(),
+            self.conf,
+            self.kconf,
+            self.iou,
+        );
+    }
+
+    pub fn engine(&self) -> &OrtBackend {
+        &self.engine
+    }
+
+    pub fn conf(&self) -> f32 {
+        self.conf
+    }
+
+    pub fn set_conf(&mut self, val: f32) {
+        self.conf = val;
+    }
+
+    pub fn conf_mut(&mut self) -> &mut f32 {
+        &mut self.conf
+    }
+
+    pub fn kconf(&self) -> f32 {
+        self.kconf
+    }
+
+    pub fn iou(&self) -> f32 {
+        self.iou
+    }
+
+    pub fn task(&self) -> &YOLOTask {
+        &self.task
+    }
+
+    pub fn batch(&self) -> u32 {
+        self.batch
+    }
+
+    pub fn width(&self) -> u32 {
+        self.width
+    }
+
+    pub fn height(&self) -> u32 {
+        self.height
+    }
+
+    pub fn nc(&self) -> u32 {
+        self.nc
+    }
+
+    pub fn nk(&self) -> u32 {
+        self.nk
+    }
+
+    pub fn nm(&self) -> u32 {
+        self.nm
+    }
+
+    pub fn names(&self) -> &Vec<String> {
+        &self.names
+    }
+}
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs
new file mode 100644
index 0000000000..5be93bdc58
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs
@@ -0,0 +1,534 @@
+use anyhow::Result;
+use clap::ValueEnum;
+use half::f16;
+use ndarray::{Array, CowArray, IxDyn};
+use ort::execution_providers::{CUDAExecutionProviderOptions, TensorRTExecutionProviderOptions};
+use ort::tensor::TensorElementDataType;
+use ort::{Environment, ExecutionProvider, Session, SessionBuilder, Value};
+use regex::Regex;
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
+pub enum YOLOTask {
+    // YOLO tasks
+    Classify,
+    Detect,
+    Pose,
+    Segment,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+pub enum OrtEP {
+    // ONNXRuntime execution provider
+    Cpu,
+    Cuda(u32),
+    Trt(u32),
+}
+
+#[derive(Debug)]
+pub struct Batch {
+    pub opt: u32,
+    pub min: u32,
+    pub max: u32,
+}
+
+impl Default for Batch {
+    fn default() -> Self {
+        Self {
+            opt: 1,
+            min: 1,
+            max: 1,
+        }
+    }
+}
+
+#[derive(Debug, Default)]
+pub struct OrtInputs {
+    // ONNX model inputs attrs
+    pub shapes: Vec<Vec<i32>>,
+    pub dtypes: Vec<TensorElementDataType>,
+    pub names: Vec<String>,
+    pub sizes: Vec<Vec<u32>>,
+}
+
+impl OrtInputs {
+    pub fn new(session: &Session) -> Self {
+        let mut shapes = Vec::new();
+        let mut dtypes = Vec::new();
+        let mut names = Vec::new();
+        for i in session.inputs.iter() {
+            let shape: Vec<i32> = i
+                .dimensions()
+                .map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
+                .collect();
+            shapes.push(shape);
+            dtypes.push(i.input_type);
+            names.push(i.name.clone());
+        }
+        Self {
+            shapes,
+            dtypes,
+            names,
+            ..Default::default()
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct OrtConfig {
+    // ORT config
+    pub f: String,
+    pub task: Option<YOLOTask>,
+    pub ep: OrtEP,
+    pub trt_fp16: bool,
+    pub batch: Batch,
+    pub image_size: (Option<u32>, Option<u32>),
+}
+
+#[derive(Debug)]
+pub struct OrtBackend {
+    // ORT engine
+    session: Session,
+    task: YOLOTask,
+    ep: OrtEP,
+    batch: Batch,
+    inputs: OrtInputs,
+}
+
+impl OrtBackend {
+    pub fn build(args: OrtConfig) -> Result<Self> {
+        // build env & session
+        let env = Environment::builder()
+            .with_name("YOLOv8")
+            .with_log_level(ort::LoggingLevel::Verbose)
+            .build()?
+            .into_arc();
+        // first session: only used to inspect the model's input shapes
+        let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
+
+        // get inputs
+        let mut inputs = OrtInputs::new(&session);
+
+        // batch size
+        let mut batch = args.batch;
+        let batch = if inputs.shapes[0][0] == -1 {
+            batch
+        } else {
+            assert_eq!(
+                inputs.shapes[0][0] as u32, batch.opt,
+                "Expected batch size: {}, got {}. Try using `--batch {}`.",
+                inputs.shapes[0][0] as u32, batch.opt, inputs.shapes[0][0] as u32
+            );
+            batch.opt = inputs.shapes[0][0] as u32;
+            batch
+        };
+
+        // input size: height and width
+        let height = if inputs.shapes[0][2] == -1 {
+            match args.image_size.0 {
+                Some(height) => height,
+                None => panic!("Failed to get model height. Make it explicit with `--height`"),
+            }
+        } else {
+            inputs.shapes[0][2] as u32
+        };
+        let width = if inputs.shapes[0][3] == -1 {
+            match args.image_size.1 {
+                Some(width) => width,
+                None => panic!("Failed to get model width. Make it explicit with `--width`"),
+            }
+        } else {
+            inputs.shapes[0][3] as u32
+        };
+        inputs.sizes.push(vec![height, width]);
+
+        // build provider
+        let (ep, provider) = match args.ep {
+            OrtEP::Cuda(device_id) => Self::set_ep_cuda(device_id),
+            OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
+            _ => (OrtEP::Cpu, ExecutionProvider::CPU(Default::default())),
+        };
+
+        // build session again with the new provider
+        let session = SessionBuilder::new(&env)?
+            // .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
+            .with_execution_providers([provider])?
+            .with_model_from_file(args.f)?;
+
+        // task: using given one or guessing
+        let task = match args.task {
+            Some(task) => task,
+            None => match session.metadata() {
+                Err(_) => panic!("No metadata found. Try making it explicit with `--task`"),
+                Ok(metadata) => match metadata.custom("task") {
+                    Err(_) => panic!("Can not get custom value. Try making it explicit with `--task`"),
+                    Ok(value) => match value {
+                        None => panic!("No corresponding value of `task` found in metadata. Make it explicit with `--task`"),
+                        Some(task) => match task.as_str() {
+                            "classify" => YOLOTask::Classify,
+                            "detect" => YOLOTask::Detect,
+                            "pose" => YOLOTask::Pose,
+                            "segment" => YOLOTask::Segment,
+                            x => todo!("{:?} is not supported for now!", x),
+                        },
+                    },
+                },
+            },
+        };
+
+        Ok(Self {
+            session,
+            task,
+            ep,
+            batch,
+            inputs,
+        })
+    }
+
+    pub fn fetch_inputs_from_session(
+        session: &Session,
+    ) -> (Vec<Vec<i32>>, Vec<TensorElementDataType>, Vec<String>) {
+        // get inputs attrs from ONNX model
+        let mut shapes = Vec::new();
+        let mut dtypes = Vec::new();
+        let mut names = Vec::new();
+        for i in session.inputs.iter() {
+            let shape: Vec<i32> = i
+                .dimensions()
+                .map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
+                .collect();
+            shapes.push(shape);
+            dtypes.push(i.input_type);
+            names.push(i.name.clone());
+        }
+        (shapes, dtypes, names)
+    }
+
+    pub fn set_ep_cuda(device_id: u32) -> (OrtEP, ExecutionProvider) {
+        // set CUDA
+        if ExecutionProvider::CUDA(Default::default()).is_available() {
+            (
+                OrtEP::Cuda(device_id),
+                ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
+                    device_id,
+                    ..Default::default()
+                }),
+            )
+        } else {
+            println!("> CUDA is not available! Using CPU.");
+            (OrtEP::Cpu, ExecutionProvider::CPU(Default::default()))
+        }
+    }
+
+    pub fn set_ep_trt(
+        device_id: u32,
+        fp16: bool,
+        batch: &Batch,
+        inputs: &OrtInputs,
+    ) -> (OrtEP, ExecutionProvider) {
+        // set TensorRT
+        if ExecutionProvider::TensorRT(Default::default()).is_available() {
+            let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
+
+            // dtype match checking
+            if inputs.dtypes[0] == TensorElementDataType::Float16 && !fp16 {
+                panic!(
+                    "Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`",
+                    inputs.dtypes[0]
+                );
+            }
+
+            // dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
+            let mut opt_string = String::new();
+            let mut min_string = String::new();
+            let mut max_string = String::new();
+            for name in inputs.names.iter() {
+                let s_opt = format!("{}:{}x3x{}x{},", name, batch.opt, height, width);
+                let s_min = format!("{}:{}x3x{}x{},", name, batch.min, height, width);
+                let s_max = format!("{}:{}x3x{}x{},", name, batch.max, height, width);
+                opt_string.push_str(s_opt.as_str());
+                min_string.push_str(s_min.as_str());
+                max_string.push_str(s_max.as_str());
+            }
+            // drop the trailing commas
+            let _ = opt_string.pop();
+            let _ = min_string.pop();
+            let _ = max_string.pop();
+            (
+                OrtEP::Trt(device_id),
+                ExecutionProvider::TensorRT(TensorRTExecutionProviderOptions {
+                    device_id,
+                    fp16_enable: fp16,
+                    timing_cache_enable: true,
+                    profile_min_shapes: min_string,
+                    profile_max_shapes: max_string,
+                    profile_opt_shapes: opt_string,
+                    ..Default::default()
+                }),
+            )
+        } else {
+            println!("> TensorRT is not available! Try using CUDA...");
+            Self::set_ep_cuda(device_id)
+        }
+    }
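+
+    // Illustrative example (added for clarity, not in the original code): for a
+    // model with a single input named "images" (the usual YOLOv8 ONNX input
+    // name), running with `--batch 2 --batch-min 1 --batch-max 8` at 640x640
+    // makes `set_ep_trt` produce the TensorRT profile strings:
+    //   profile_min_shapes: "images:1x3x640x640"
+    //   profile_opt_shapes: "images:2x3x640x640"
+    //   profile_max_shapes: "images:8x3x640x640"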
+
+    pub fn fetch_from_metadata(&self, key: &str) -> Option<String> {
+        // fetch value from onnx model file by key
+        match self.session.metadata() {
+            Err(_) => None,
+            Ok(metadata) => match metadata.custom(key) {
+                Err(_) => None,
+                Ok(value) => value,
+            },
+        }
+    }
+
+    pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
+        // ORT inference
+        match self.dtype() {
+            TensorElementDataType::Float16 => self.run_fp16(xs, profile),
+            TensorElementDataType::Float32 => self.run_fp32(xs, profile),
+            _ => todo!(),
+        }
+    }
+
+    pub fn run_fp16(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
+        // f32->f16
+        let t = std::time::Instant::now();
+        let xs = xs.mapv(f16::from_f32);
+        if profile {
+            println!("[ORT f32->f16]: {:?}", t.elapsed());
+        }
+
+        // h2d
+        let t = std::time::Instant::now();
+        let xs = CowArray::from(xs);
+        let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
+        if profile {
+            println!("[ORT H2D]: {:?}", t.elapsed());
+        }
+
+        // run
+        let t = std::time::Instant::now();
+        let ys = self.session.run(xs)?;
+        if profile {
+            println!("[ORT Inference]: {:?}", t.elapsed());
+        }
+
+        // d2h
+        Ok(ys
+            .iter()
+            .map(|x| {
+                // d2h
+                let t = std::time::Instant::now();
+                let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
+                if profile {
+                    println!("[ORT D2H]: {:?}", t.elapsed());
+                }
+
+                // f16->f32
+                let t_ = std::time::Instant::now();
+                let x = x.mapv(f16::to_f32);
+                if profile {
+                    println!("[ORT f16->f32]: {:?}", t_.elapsed());
+                }
+                x
+            })
+            .collect::<Vec<Array<f32, IxDyn>>>())
+    }
+
+    pub fn run_fp32(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
+        // h2d
+        let t = std::time::Instant::now();
+        let xs = CowArray::from(xs);
+        let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
+        if profile {
+            println!("[ORT H2D]: {:?}", t.elapsed());
+        }
+
+        // run
+        let t = std::time::Instant::now();
+        let ys = self.session.run(xs)?;
+        if profile {
+            println!("[ORT Inference]: {:?}", t.elapsed());
+        }
+
+        // d2h
+        Ok(ys
+            .iter()
+            .map(|x| {
+                let t = std::time::Instant::now();
+                let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
+                if profile {
+                    println!("[ORT D2H]: {:?}", t.elapsed());
+                }
+                x
+            })
+            .collect::<Vec<Array<f32, IxDyn>>>())
+    }
+
+    pub fn output_shapes(&self) -> Vec<Vec<i32>> {
+        let mut shapes = Vec::new();
+        for o in &self.session.outputs {
+            let shape: Vec<_> = o
+                .dimensions()
+                .map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
+                .collect();
+            shapes.push(shape);
+        }
+        shapes
+    }
+
+    pub fn output_dtypes(&self) -> Vec<TensorElementDataType> {
+        let mut dtypes = Vec::new();
+        self.session
+            .outputs
+            .iter()
+            .for_each(|x| dtypes.push(x.output_type));
+        dtypes
+    }
+
+    pub fn input_shapes(&self) -> &Vec<Vec<i32>> {
+        &self.inputs.shapes
+    }
+
+    pub fn input_names(&self) -> &Vec<String> {
+        &self.inputs.names
+    }
+
+    pub fn input_dtypes(&self) -> &Vec<TensorElementDataType> {
+        &self.inputs.dtypes
+    }
+
+    pub fn dtype(&self) -> TensorElementDataType {
+        self.input_dtypes()[0]
+    }
+
+    pub fn height(&self) -> u32 {
+        self.inputs.sizes[0][0]
+    }
+
+    pub fn width(&self) -> u32 {
+        self.inputs.sizes[0][1]
+    }
+
+    pub fn is_height_dynamic(&self) -> bool {
+        self.input_shapes()[0][2] == -1
+    }
+
+    pub fn is_width_dynamic(&self) -> bool {
+        self.input_shapes()[0][3] == -1
+    }
+
+    pub fn batch(&self) -> u32 {
+        self.batch.opt
+    }
+
+    pub fn is_batch_dynamic(&self) -> bool {
+        self.input_shapes()[0][0] == -1
+    }
+
+    pub fn ep(&self) -> &OrtEP {
+        &self.ep
+    }
+
+    pub fn task(&self) -> YOLOTask {
+        self.task.clone()
+    }
+
+    pub fn names(&self) -> Option<Vec<String>> {
+        // class names, metadata parsing
+        // String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}`
+        match self.fetch_from_metadata("names") {
+            Some(names) => {
+                let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap();
+                let mut names_ = vec![];
+                for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) {
+                    names_.push(name.to_string());
+                }
+                Some(names_)
+            }
+            None => None,
+        }
+    }
+
+    pub fn nk(&self) -> Option<u32> {
+        // num_keypoints, metadata parsing: String `nk` in onnx model: `[17, 3]`
+        match self.fetch_from_metadata("kpt_shape") {
+            None => None,
+            Some(kpt_string) => {
+                let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
+                let caps = re.captures(&kpt_string).unwrap();
+                Some(caps.get(1).unwrap().as_str().parse::<u32>().unwrap())
+            }
+        }
+    }
+
+    pub fn nc(&self) -> Option<u32> {
+        // num_classes
+        match self.names() {
+            // by names
+            Some(names) => Some(names.len() as u32),
+            None => match self.task() {
+                // by task calculation
+                YOLOTask::Classify => Some(self.output_shapes()[0][1] as u32),
+                YOLOTask::Detect => {
+                    if self.output_shapes()[0][1] == -1 {
+                        None
+                    } else {
+                        // cxywhclss
+                        Some(self.output_shapes()[0][1] as u32 - 4)
+                    }
+                }
+                YOLOTask::Pose => {
+                    match self.nk() {
+                        None => None,
+                        Some(nk) => {
+                            if self.output_shapes()[0][1] == -1 {
+                                None
+                            } else {
+                                // cxywhclss3*kpt
+                                Some(self.output_shapes()[0][1] as u32 - 4 - 3 * nk)
+                            }
+                        }
+                    }
+                }
+                YOLOTask::Segment => {
+                    if self.output_shapes()[0][1] == -1 {
+                        None
+                    } else {
+                        // cxywhclssnm
+                        Some((self.output_shapes()[0][1] - self.output_shapes()[1][1]) as u32 - 4)
+                    }
+                }
+            },
+        }
+    }
+
+    pub fn nm(&self) -> Option<u32> {
+        // num_masks
+        match self.task() {
+            YOLOTask::Segment => Some(self.output_shapes()[1][1] as u32),
+            _ => None,
+        }
+    }
+
+    pub fn na(&self) -> Option<u32> {
+        // num_anchors
+        match self.task() {
+            YOLOTask::Segment | YOLOTask::Detect | YOLOTask::Pose => {
+                if self.output_shapes()[0][2] == -1 {
+                    None
+                } else {
+                    Some(self.output_shapes()[0][2] as u32)
+                }
+            }
+            _ => None,
+        }
+    }
+
+    pub fn author(&self) -> Option<String> {
+        self.fetch_from_metadata("author")
+    }
+
+    pub fn version(&self) -> Option<String> {
+        self.fetch_from_metadata("version")
+    }
+}
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/yolo_result.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/yolo_result.rs
new file mode 100644
index 0000000000..2fcc6d8602
--- /dev/null
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/yolo_result.rs
@@ -0,0 +1,235 @@
+use ndarray::{Array, Axis, IxDyn};
+
+#[derive(Clone, PartialEq, Default)]
+pub struct YOLOResult {
+    // YOLO tasks results of an image
+    pub probs: Option<Embedding>,
+    pub bboxes: Option<Vec<Bbox>>,
+    pub keypoints: Option<Vec<Vec<Point2>>>,
+    pub masks: Option<Vec<Vec<u8>>>,
+}
+
+impl std::fmt::Debug for YOLOResult {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("YOLOResult")
+            .field(
+                "Probs(top5)",
+                &format_args!("{:?}", self.probs().map(|probs| probs.topk(5))),
+            )
+            .field("Bboxes", &self.bboxes)
+            .field("Keypoints", &self.keypoints)
+            .field(
+                "Masks",
+                &format_args!("{:?}", self.masks().map(|masks| masks.len())),
+            )
+            .finish()
+    }
+}
+
+impl YOLOResult {
+    pub fn new(
+        probs: Option<Embedding>,
+        bboxes: Option<Vec<Bbox>>,
+        keypoints: Option<Vec<Vec<Point2>>>,
+        masks: Option<Vec<Vec<u8>>>,
+    ) -> Self {
+        Self {
+            probs,
+            bboxes,
+            keypoints,
+            masks,
+        }
+    }
+
+    pub fn probs(&self) -> Option<&Embedding> {
+        self.probs.as_ref()
+    }
+
+    pub fn keypoints(&self) -> Option<&Vec<Vec<Point2>>> {
+        self.keypoints.as_ref()
+    }
+
+    pub fn masks(&self) -> Option<&Vec<Vec<u8>>> {
+        self.masks.as_ref()
+    }
+
+    pub fn bboxes(&self) -> Option<&Vec<Bbox>> {
+        self.bboxes.as_ref()
+    }
+
+    pub fn bboxes_mut(&mut self) -> Option<&mut Vec<Bbox>> {
+        self.bboxes.as_mut()
+    }
+}
+
+#[derive(Debug, PartialEq, Clone, Default)]
+pub struct Point2 {
+    // A point2d with x, y, conf
+    x: f32,
+    y: f32,
+    confidence: f32,
+}
+
+impl Point2 {
+    pub fn new_with_conf(x: f32, y: f32, confidence: f32) -> Self {
+        Self { x, y, confidence }
+    }
+
+    pub fn new(x: f32, y: f32) -> Self {
+        Self {
+            x,
+            y,
+            ..Default::default()
+        }
+    }
+
+    pub fn x(&self) -> f32 {
+        self.x
+    }
+
+    pub fn y(&self) -> f32 {
+        self.y
+    }
+
+    pub fn confidence(&self) -> f32 {
+        self.confidence
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Default)]
+pub struct Embedding {
+    // A float32 n-dims tensor
+    data: Array<f32, IxDyn>,
+}
+
+impl Embedding {
+    pub fn new(data: Array<f32, IxDyn>) -> Self {
+        Self { data }
+    }
+
+    pub fn data(&self) -> &Array<f32, IxDyn> {
+        &self.data
+    }
+
+    pub fn topk(&self, k: usize) -> Vec<(usize, f32)> {
+        let mut probs = self
+            .data
+            .iter()
+            .enumerate()
+            .map(|(a, b)| (a, *b))
+            .collect::<Vec<_>>();
+        probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
+        let mut topk = Vec::new();
+        for &(id, confidence) in probs.iter().take(k) {
+            topk.push((id, confidence));
+        }
+        topk
+    }
+
+    pub fn norm(&self) -> Array<f32, IxDyn> {
+        let std_ = self.data.mapv(|x| x * x).sum_axis(Axis(0)).mapv(f32::sqrt);
+        self.data.clone() / std_
+    }
+
+    pub fn top1(&self) -> (usize, f32) {
+        self.topk(1)[0]
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Default)]
+pub struct Bbox {
+    // a bounding box around an object
+    xmin: f32,
+    ymin: f32,
+    width: f32,
+    height: f32,
+    id: usize,
+    confidence: f32,
+}
+
+impl Bbox {
+    pub fn new_from_xywh(xmin: f32, ymin: f32, width: f32, height: f32) -> Self {
+        Self {
+            xmin,
+            ymin,
+            width,
+            height,
+            ..Default::default()
+        }
+    }
+
+    pub fn new(xmin: f32, ymin: f32, width: f32, height: f32, id: usize, confidence: f32) -> Self {
+        Self {
+            xmin,
+            ymin,
+            width,
+            height,
+            id,
+            confidence,
+        }
+    }
+
+    pub fn width(&self) -> f32 {
+        self.width
+    }
+
+    pub fn height(&self) -> f32 {
+        self.height
+    }
+
+    pub fn xmin(&self) -> f32 {
+        self.xmin
+    }
+
+    pub fn ymin(&self) -> f32 {
+        self.ymin
+    }
+
+    pub fn xmax(&self) -> f32 {
+        self.xmin + self.width
+    }
+
+    pub fn ymax(&self) -> f32 {
+        self.ymin + self.height
+    }
+
+    pub fn tl(&self) -> Point2 {
+        Point2::new(self.xmin, self.ymin)
+    }
+
+    pub fn br(&self) -> Point2 {
+        Point2::new(self.xmax(), self.ymax())
+    }
+
+    pub fn cxcy(&self) -> Point2 {
+        Point2::new(self.xmin + self.width / 2., self.ymin + self.height / 2.)
+    }
+
+    pub fn id(&self) -> usize {
+        self.id
+    }
+
+    pub fn confidence(&self) -> f32 {
+        self.confidence
+    }
+
+    pub fn area(&self) -> f32 {
+        self.width * self.height
+    }
+
+    pub fn intersection_area(&self, another: &Bbox) -> f32 {
+        let l = self.xmin.max(another.xmin);
+        let r = (self.xmin + self.width).min(another.xmin + another.width);
+        let t = self.ymin.max(another.ymin);
+        let b = (self.ymin + self.height).min(another.ymin + another.height);
+        // the +1 treats coordinates as inclusive pixel indices
+        (r - l + 1.).max(0.) * (b - t + 1.).max(0.)
+    }
+
+    pub fn union(&self, another: &Bbox) -> f32 {
+        self.area() + another.area() - self.intersection_area(another)
+    }
+
+    pub fn iou(&self, another: &Bbox) -> f32 {
+        self.intersection_area(another) / self.union(another)
+    }
+}
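+
+// Quick sanity checks for the box math, added for illustration (not part of the
+// original example). Because `intersection_area` uses the inclusive +1 pixel
+// convention, the IoU of two heavily overlapping boxes can exceed 1.0, so the
+// assertions below only exercise the disjoint case and symmetry.
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn iou_of_disjoint_boxes_is_zero() {
+        let a = Bbox::new(0., 0., 10., 10., 0, 0.9);
+        let b = Bbox::new(50., 50., 10., 10., 0, 0.9);
+        assert_eq!(a.intersection_area(&b), 0.);
+        assert_eq!(a.iou(&b), 0.);
+    }
+
+    #[test]
+    fn intersection_is_symmetric() {
+        let a = Bbox::new(0., 0., 10., 10., 0, 0.9);
+        let b = Bbox::new(5., 5., 10., 10., 0, 0.9);
+        assert_eq!(a.intersection_area(&b), b.intersection_area(&a));
+    }
+}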