[Example] YOLOv8-ONNXRuntime-Rust example (#6583)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
jamjamjon 1 year ago committed by GitHub
parent 3c277347e4
commit fdcf0dd4fd
  1. examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml (+21)
  2. examples/YOLOv8-ONNXRuntime-Rust/README.md (+222)
  3. examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs (+87)
  4. examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs (+119)
  5. examples/YOLOv8-ONNXRuntime-Rust/src/main.rs (+28)
  6. examples/YOLOv8-ONNXRuntime-Rust/src/model.rs (+642)
  7. examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs (+534)
  8. examples/YOLOv8-ONNXRuntime-Rust/src/yolo_result.rs (+235)

@@ -0,0 +1,21 @@
[package]
name = "yolov8-rs"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
clap = { version = "4.2.4", features = ["derive"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png", "webp-encoder"] }
imageproc = { version = "0.23.0", default-features = false }
ndarray = { version = "0.15.6" }
ort = { version = "1.16.3", default-features = false, features = ["load-dynamic", "copy-dylibs", "half"] }
rusttype = { version = "0.9", default-features = false }
anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" }
rand = { version = "0.8.5" }
chrono = { version = "0.4.30" }
half = { version = "2.3.1" }
dirs = { version = "5.0.1" }
ureq = { version = "2.9.1" }

@@ -0,0 +1,222 @@
# YOLOv8-ONNXRuntime-Rust for All the Key YOLO Tasks
This repository provides a Rust demo for performing YOLOv8 tasks like `Classification`, `Segmentation`, `Detection` and `Pose Detection` using ONNXRuntime.
## Features
- Supports `Classification`, `Segmentation`, `Detection` and `Pose (Keypoints) Detection` tasks.
- Supports `FP16` & `FP32` ONNX models.
- Supports `CPU`, `CUDA` and `TensorRT` execution providers to accelerate computation.
- Supports dynamic input shapes (`batch`, `width`, `height`).
## Installation
### 1. Install Rust
Please follow the official Rust installation guide (https://www.rust-lang.org/tools/install).
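For example, on Linux or macOS you can use the standard `rustup` installer:
```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
```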
### 2. Install ONNXRuntime
This repository uses the `ort` crate, a Rust wrapper for ONNXRuntime (https://docs.rs/ort/latest/ort/).
You can follow the instructions in the `ort` docs or simply do this:
- step 1: Download ONNXRuntime (https://github.com/microsoft/onnxruntime/releases)
- step 2: Set the library path environment variable for linking.
On Ubuntu, you can do it like this:
```bash
vim ~/.bashrc
# Add the path of the ONNXRuntime lib
export LD_LIBRARY_PATH=/home/qweasd/Documents/onnxruntime-linux-x64-gpu-1.16.3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
source ~/.bashrc
```
### 3. \[Optional\] Install CUDA & CuDNN & TensorRT
- CUDA execution provider requires CUDA v11.6+.
- TensorRT execution provider requires CUDA v11.4+ and TensorRT v8.4+.
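As a quick sanity check, you can verify the CUDA toolkit version on your machine (this assumes `nvcc` is on your `PATH`):
```bash
nvcc --version
```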
## Get Started
### 1. Export the YOLOv8 ONNX Models
```bash
pip install -U ultralytics
# export onnx model with dynamic shapes
yolo export model=yolov8m.pt format=onnx simplify dynamic
yolo export model=yolov8m-cls.pt format=onnx simplify dynamic
yolo export model=yolov8m-pose.pt format=onnx simplify dynamic
yolo export model=yolov8m-seg.pt format=onnx simplify dynamic
# export onnx model with constant shapes
yolo export model=yolov8m.pt format=onnx simplify
yolo export model=yolov8m-cls.pt format=onnx simplify
yolo export model=yolov8m-pose.pt format=onnx simplify
yolo export model=yolov8m-seg.pt format=onnx simplify
```
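Optionally, you can sanity-check an exported model with the `onnx` Python package (this assumes `pip install onnx`; `yolov8m.onnx` is just an example filename):
```bash
python -c "import onnx; m = onnx.load('yolov8m.onnx'); onnx.checker.check_model(m); print([i.name for i in m.graph.input])"
```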
### 2. Run Inference
It will perform inference with the ONNX model on the source image.
```bash
cargo run --release -- --model <MODEL> --source <SOURCE>
```
Set `--cuda` to use CUDA execution provider to speed up inference.
```bash
cargo run --release -- --cuda --model <MODEL> --source <SOURCE>
```
Set `--trt` to use TensorRT execution provider, and you can set `--fp16` at the same time to use TensorRT FP16 engine.
```bash
cargo run --release -- --trt --fp16 --model <MODEL> --source <SOURCE>
```
Set `--device_id` to select which device to run on. If you have only one GPU and set `device_id` to 1, the program will not panic; `ort` automatically falls back to the `CPU` EP.
```bash
cargo run --release -- --cuda --device_id 0 --model <MODEL> --source <SOURCE>
```
Set `--batch` to do multi-batch-size inference.
If you're using `--trt`, you can also set `--batch-min` and `--batch-max` to explicitly specify the min/max/opt batch sizes for dynamic batch input (https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#explicit-shape-range-for-dynamic-shape-input). (Note that the ONNX model should be exported with dynamic shapes; an explicit-range example is shown after the block below.)
```bash
cargo run --release -- --cuda --batch 2 --model <MODEL> --source <SOURCE>
```
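For instance, a TensorRT run with an explicit batch range might look like this (the range values are illustrative):
```bash
cargo run --release -- --trt --batch 2 --batch-min 1 --batch-max 4 --model <MODEL> --source <SOURCE>
```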
Set `--height` and `--width` to do dynamic image size inference. (Note that the ONNX model should be exported with dynamic shapes.)
```bash
cargo run --release -- --cuda --width 480 --height 640 --model <MODEL> --source <SOURCE>
```
Set `--profile` to check the time consumed in each stage. (Note that the model usually needs 1~3 dry runs to warm up. Make sure to run enough times to evaluate the result.)
```bash
cargo run --release -- --trt --fp16 --profile --model <MODEL> --source <SOURCE>
```
Example results (yolov8m.onnx, batch=1, 3 runs, TensorRT, FP16, RTX 3060Ti):
```
==> 0
[Model Preprocess]: 12.75788ms
[ORT H2D]: 237.118µs
[ORT Inference]: 507.895469ms
[ORT D2H]: 191.655µs
[Model Inference]: 508.34589ms
[Model Postprocess]: 1.061122ms
==> 1
[Model Preprocess]: 13.658655ms
[ORT H2D]: 209.975µs
[ORT Inference]: 5.12372ms
[ORT D2H]: 182.389µs
[Model Inference]: 5.530022ms
[Model Postprocess]: 1.04851ms
==> 2
[Model Preprocess]: 12.475332ms
[ORT H2D]: 246.127µs
[ORT Inference]: 5.048432ms
[ORT D2H]: 187.117µs
[Model Inference]: 5.493119ms
[Model Postprocess]: 1.040906ms
```
Other options:
`--conf`: confidence threshold \[default: 0.3\]
`--iou`: IoU threshold in NMS \[default: 0.45\]
`--kconf`: keypoint confidence threshold \[default: 0.55\]
`--plot`: plot inference results with random RGB colors and save
You can check out all CLI arguments by:
```bash
git clone https://github.com/ultralytics/ultralytics
cd ultralytics/examples/YOLOv8-ONNXRuntime-Rust
cargo run --release -- --help
```
## Examples
### Classification
Running a dynamic-shape ONNX model on the `CPU` with image size `--height 224 --width 224`.
The plotted image is saved in the `runs` directory.
```bash
cargo run --release -- --model ../assets/weights/yolov8m-cls-dyn.onnx --source ../assets/images/dog.jpg --height 224 --width 224 --plot --profile
```
You will see a result like:
```
Summary:
> Task: Classify (Ultralytics 8.0.217)
> EP: Cpu
> Dtype: Float32
> Batch: 1 (Dynamic), Height: 224 (Dynamic), Width: 224 (Dynamic)
> nc: 1000, nk: 0, nm: 0, conf: 0.3, kconf: 0.55, iou: 0.45
[Model Preprocess]: 16.363477ms
[ORT H2D]: 50.722µs
[ORT Inference]: 16.295808ms
[ORT D2H]: 8.37µs
[Model Inference]: 16.367046ms
[Model Postprocess]: 3.527µs
[
YOLOResult {
Probs(top5): Some([(208, 0.6950566), (209, 0.13823675), (178, 0.04849795), (215, 0.019029364), (212, 0.016506357)]),
Bboxes: None,
Keypoints: None,
Masks: None,
},
]
```
![2023-11-25-22-02-02-156623351](https://github.com/jamjamjon/ultralytics/assets/51357717/ef75c2ae-c5ab-44cc-9d9e-e60b51e39662)
### Object Detection
Using the `CUDA` EP and dynamic image size `--height 640 --width 480`.
```bash
cargo run --release -- --cuda --model ../assets/weights/yolov8m-dynamic.onnx --source ../assets/images/bus.jpg --plot --height 640 --width 480
```
![det](https://github.com/jamjamjon/ultralytics/assets/51357717/5d89a19d-0c96-4a59-875c-defab6887a2c)
### Pose Detection
Using the `TensorRT` EP.
```bash
cargo run --release -- --trt --model ../assets/weights/yolov8m-pose.onnx --source ../assets/images/bus.jpg --plot
```
![2023-11-25-22-31-45-127054025](https://github.com/jamjamjon/ultralytics/assets/51357717/157b5ba7-bfcf-47cf-bee7-68b62e0de1c4)
### Instance Segmentation
Using the `TensorRT` EP and an FP16 model with `--fp16`.
```bash
cargo run --release -- --trt --fp16 --model ../assets/weights/yolov8m-seg.onnx --source ../assets/images/0172.jpg --plot
```
![seg](https://github.com/jamjamjon/ultralytics/assets/51357717/cf046f4f-9533-478a-adc7-4de22443a641)

@@ -0,0 +1,87 @@
use clap::Parser;
use crate::YOLOTask;
#[derive(Parser, Clone)]
#[command(author, version, about, long_about = None)]
pub struct Args {
/// ONNX model path
#[arg(long, required = true)]
pub model: String,
/// input path
#[arg(long, required = true)]
pub source: String,
/// device id
#[arg(long, default_value_t = 0)]
pub device_id: u32,
/// using TensorRT EP
#[arg(long)]
pub trt: bool,
/// using CUDA EP
#[arg(long)]
pub cuda: bool,
/// input batch size
#[arg(long, default_value_t = 1)]
pub batch: u32,
/// trt input min_batch size
#[arg(long, default_value_t = 1)]
pub batch_min: u32,
/// trt input max_batch size
#[arg(long, default_value_t = 32)]
pub batch_max: u32,
/// using TensorRT --fp16
#[arg(long)]
pub fp16: bool,
/// specify YOLO task
#[arg(long, value_enum)]
pub task: Option<YOLOTask>,
/// num_classes
#[arg(long)]
pub nc: Option<u32>,
/// num_keypoints
#[arg(long)]
pub nk: Option<u32>,
/// num_masks
#[arg(long)]
pub nm: Option<u32>,
/// input image width
#[arg(long)]
pub width: Option<u32>,
/// input image height
#[arg(long)]
pub height: Option<u32>,
/// confidence threshold
#[arg(long, required = false, default_value_t = 0.3)]
pub conf: f32,
/// iou threshold in NMS
#[arg(long, required = false, default_value_t = 0.45)]
pub iou: f32,
/// confidence threshold of keypoint
#[arg(long, required = false, default_value_t = 0.55)]
pub kconf: f32,
/// plot inference result and save
#[arg(long)]
pub plot: bool,
/// check time consumed in each stage
#[arg(long)]
pub profile: bool,
}

@@ -0,0 +1,119 @@
#![allow(clippy::type_complexity)]
use std::io::{Read, Write};
pub mod cli;
pub mod model;
pub mod ort_backend;
pub mod yolo_result;
pub use crate::cli::Args;
pub use crate::model::YOLOv8;
pub use crate::ort_backend::{Batch, OrtBackend, OrtConfig, OrtEP, YOLOTask};
pub use crate::yolo_result::{Bbox, Embedding, Point2, YOLOResult};
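// Greedy non-maximum suppression: candidates are sorted by confidence (descending),
// and a box is kept only if its IoU with every previously kept box does not exceed
// `iou_threshold`; kept entries are compacted to the front and the rest truncated.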
pub fn non_max_suppression(
xs: &mut Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)>,
iou_threshold: f32,
) {
xs.sort_by(|b1, b2| b2.0.confidence().partial_cmp(&b1.0.confidence()).unwrap());
let mut current_index = 0;
for index in 0..xs.len() {
let mut drop = false;
for prev_index in 0..current_index {
let iou = xs[prev_index].0.iou(&xs[index].0);
if iou > iou_threshold {
drop = true;
break;
}
}
if !drop {
xs.swap(current_index, index);
current_index += 1;
}
}
xs.truncate(current_index);
}
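// Timestamp string in UTC+8 (Beijing time), used to name saved result images.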
pub fn gen_time_string(delimiter: &str) -> String {
let offset = chrono::FixedOffset::east_opt(8 * 60 * 60).unwrap(); // Beijing
let t_now = chrono::Utc::now().with_timezone(&offset);
let fmt = format!(
"%Y{}%m{}%d{}%H{}%M{}%S{}%f",
delimiter, delimiter, delimiter, delimiter, delimiter, delimiter
);
t_now.format(&fmt).to_string()
}
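// Keypoint index pairs (17-keypoint COCO layout) connected when drawing pose skeletons.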
pub const SKELETON: [(usize, usize); 16] = [
(0, 1),
(0, 2),
(1, 3),
(2, 4),
(5, 6),
(5, 11),
(6, 12),
(11, 12),
(5, 7),
(6, 8),
(7, 9),
(8, 10),
(11, 13),
(12, 14),
(13, 15),
(14, 16),
];
pub fn check_font(font: &str) -> rusttype::Font<'static> {
// check then load font
// ultralytics font path
let font_path_config = match dirs::config_dir() {
Some(mut d) => {
d.push("Ultralytics");
d.push(font);
d
}
None => panic!("Unsupported operating system. Now support Linux, MacOS, Windows."),
};
// current font path
let font_path_current = std::path::PathBuf::from(font);
// check font
let font_path = if font_path_config.exists() {
font_path_config
} else if font_path_current.exists() {
font_path_current
} else {
println!("Downloading font...");
let source_url = "https://ultralytics.com/assets/Arial.ttf";
let resp = ureq::get(source_url)
.timeout(std::time::Duration::from_secs(500))
.call()
.unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));
// read to buffer
let mut buffer = vec![];
let total_size = resp
.header("Content-Length")
.and_then(|s| s.parse::<u64>().ok())
.unwrap();
let _reader = resp
.into_reader()
.take(total_size)
.read_to_end(&mut buffer)
.unwrap();
// save
let _path = std::fs::File::create(font).unwrap();
let mut writer = std::io::BufWriter::new(_path);
writer.write_all(&buffer).unwrap();
println!("Font saved at: {:?}", font_path_current.display());
font_path_current
};
// load font
let buffer = std::fs::read(font_path).unwrap();
rusttype::Font::try_from_vec(buffer).unwrap()
}

@@ -0,0 +1,28 @@
use clap::Parser;
use yolov8_rs::{Args, YOLOv8};
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
// 1. load image
let x = image::io::Reader::open(&args.source)?
.with_guessed_format()?
.decode()?;
// 2. the model supports dynamic batch inference, so the input should be a Vec
let xs = vec![x];
// You can test `--batch 2` with this
// let xs = vec![x.clone(), x];
// 3. build yolov8 model
let mut model = YOLOv8::new(args)?;
model.summary(); // model info
// 4. run
let ys = model.run(&xs)?;
println!("{:?}", ys);
Ok(())
}

@@ -0,0 +1,642 @@
#![allow(clippy::type_complexity)]
use anyhow::Result;
use image::{DynamicImage, GenericImageView, ImageBuffer};
use ndarray::{s, Array, Axis, IxDyn};
use rand::{thread_rng, Rng};
use std::path::PathBuf;
use crate::{
check_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON,
};
pub struct YOLOv8 {
// YOLOv8 model for all yolo-tasks
engine: OrtBackend,
nc: u32,
nk: u32,
nm: u32,
height: u32,
width: u32,
batch: u32,
task: YOLOTask,
conf: f32,
kconf: f32,
iou: f32,
names: Vec<String>,
color_palette: Vec<(u8, u8, u8)>,
profile: bool,
plot: bool,
}
impl YOLOv8 {
pub fn new(config: Args) -> Result<Self> {
// execution provider
let ep = if config.trt {
OrtEP::Trt(config.device_id)
} else if config.cuda {
OrtEP::Cuda(config.device_id)
} else {
OrtEP::Cpu
};
// batch
let batch = Batch {
opt: config.batch,
min: config.batch_min,
max: config.batch_max,
};
// build ort engine
let ort_args = OrtConfig {
ep,
batch,
f: config.model,
task: config.task,
trt_fp16: config.fp16,
image_size: (config.height, config.width),
};
let engine = OrtBackend::build(ort_args)?;
// get batch, height, width, tasks, nc, nk, nm
let (batch, height, width, task) = (
engine.batch(),
engine.height(),
engine.width(),
engine.task(),
);
let nc = engine.nc().or(config.nc).unwrap_or_else(|| {
panic!("Failed to get num_classes, make it explicit with `--nc`");
});
let (nk, nm) = match task {
YOLOTask::Pose => {
let nk = engine.nk().or(config.nk).unwrap_or_else(|| {
panic!("Failed to get num_keypoints, make it explicit with `--nk`");
});
(nk, 0)
}
YOLOTask::Segment => {
let nm = engine.nm().or(config.nm).unwrap_or_else(|| {
panic!("Failed to get num_masks, make it explicit with `--nm`");
});
(0, nm)
}
_ => (0, 0),
};
// class names
let names = engine.names().unwrap_or(vec!["Unknown".to_string()]);
// color palette
let mut rng = thread_rng();
let color_palette: Vec<_> = names
.iter()
.map(|_| {
(
rng.gen_range(0..=255),
rng.gen_range(0..=255),
rng.gen_range(0..=255),
)
})
.collect();
Ok(Self {
engine,
names,
conf: config.conf,
kconf: config.kconf,
iou: config.iou,
color_palette,
profile: config.profile,
plot: config.plot,
nc,
nk,
nm,
height,
width,
batch,
task,
})
}
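// Letterbox scaling helper: returns the ratio r = min(w1/w0, h1/h0) and the rounded
// resized width and height that preserve the source aspect ratio.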
pub fn scale_wh(&self, w0: f32, h0: f32, w1: f32, h1: f32) -> (f32, f32, f32) {
let r = (w1 / w0).min(h1 / h0);
(r, (w0 * r).round(), (h0 * r).round())
}
pub fn preprocess(&mut self, xs: &Vec<DynamicImage>) -> Result<Array<f32, IxDyn>> {
let mut ys =
Array::ones((xs.len(), 3, self.height() as usize, self.width() as usize)).into_dyn();
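// fill with constant gray (144/255) so the letterbox borders have a uniform padding color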
ys.fill(144.0 / 255.0);
for (idx, x) in xs.iter().enumerate() {
let img = match self.task() {
YOLOTask::Classify => x.resize_exact(
self.width(),
self.height(),
image::imageops::FilterType::Triangle,
),
_ => {
let (w0, h0) = x.dimensions();
let w0 = w0 as f32;
let h0 = h0 as f32;
let (_, w_new, h_new) =
self.scale_wh(w0, h0, self.width() as f32, self.height() as f32); // f32 round
x.resize_exact(
w_new as u32,
h_new as u32,
if let YOLOTask::Segment = self.task() {
image::imageops::FilterType::CatmullRom
} else {
image::imageops::FilterType::Triangle
},
)
}
};
for (x, y, rgb) in img.pixels() {
let x = x as usize;
let y = y as usize;
let [r, g, b, _] = rgb.0;
ys[[idx, 0, y, x]] = (r as f32) / 255.0;
ys[[idx, 1, y, x]] = (g as f32) / 255.0;
ys[[idx, 2, y, x]] = (b as f32) / 255.0;
}
}
Ok(ys)
}
pub fn run(&mut self, xs: &Vec<DynamicImage>) -> Result<Vec<YOLOResult>> {
// pre-process
let t_pre = std::time::Instant::now();
let xs_ = self.preprocess(xs)?;
if self.profile {
println!("[Model Preprocess]: {:?}", t_pre.elapsed());
}
// run
let t_run = std::time::Instant::now();
let ys = self.engine.run(xs_, self.profile)?;
if self.profile {
println!("[Model Inference]: {:?}", t_run.elapsed());
}
// post-process
let t_post = std::time::Instant::now();
let ys = self.postprocess(ys, xs)?;
if self.profile {
println!("[Model Postprocess]: {:?}", t_post.elapsed());
}
// plot and save
if self.plot {
self.plot_and_save(&ys, xs, Some(&SKELETON));
}
Ok(ys)
}
pub fn postprocess(
&self,
xs: Vec<Array<f32, IxDyn>>,
xs0: &[DynamicImage],
) -> Result<Vec<YOLOResult>> {
if let YOLOTask::Classify = self.task() {
let mut ys = Vec::new();
let preds = &xs[0];
for batch in preds.axis_iter(Axis(0)) {
ys.push(YOLOResult::new(
Some(Embedding::new(batch.into_owned())),
None,
None,
None,
));
}
Ok(ys)
} else {
const CXYWH_OFFSET: usize = 4; // cxcywh
const KPT_STEP: usize = 3; // xyconf
let preds = &xs[0];
let protos = {
if xs.len() > 1 {
Some(&xs[1])
} else {
None
}
};
let mut ys = Vec::new();
for (idx, anchor) in preds.axis_iter(Axis(0)).enumerate() {
// [bs, 4 + nc + nm, anchors]
// input image
let width_original = xs0[idx].width() as f32;
let height_original = xs0[idx].height() as f32;
let ratio = (self.width() as f32 / width_original)
.min(self.height() as f32 / height_original);
// save each result
let mut data: Vec<(Bbox, Option<Vec<Point2>>, Option<Vec<f32>>)> = Vec::new();
for pred in anchor.axis_iter(Axis(1)) {
// split preds for different tasks
let bbox = pred.slice(s![0..CXYWH_OFFSET]);
let clss = pred.slice(s![CXYWH_OFFSET..CXYWH_OFFSET + self.nc() as usize]);
let kpts = {
if let YOLOTask::Pose = self.task() {
Some(pred.slice(s![pred.len() - KPT_STEP * self.nk() as usize..]))
} else {
None
}
};
let coefs = {
if let YOLOTask::Segment = self.task() {
Some(pred.slice(s![pred.len() - self.nm() as usize..]).to_vec())
} else {
None
}
};
// confidence and id
let (id, &confidence) = clss
.into_iter()
.enumerate()
.reduce(|max, x| if x.1 > max.1 { x } else { max })
.unwrap(); // definitely will not panic!
// confidence filter
if confidence < self.conf {
continue;
}
// bbox re-scale
let cx = bbox[0] / ratio;
let cy = bbox[1] / ratio;
let w = bbox[2] / ratio;
let h = bbox[3] / ratio;
let x = cx - w / 2.;
let y = cy - h / 2.;
let y_bbox = Bbox::new(
x.max(0.0f32).min(width_original),
y.max(0.0f32).min(height_original),
w,
h,
id,
confidence,
);
// kpts
let y_kpts = {
if let Some(kpts) = kpts {
let mut kpts_ = Vec::new();
// rescale
for i in 0..self.nk() as usize {
let kx = kpts[KPT_STEP * i] / ratio;
let ky = kpts[KPT_STEP * i + 1] / ratio;
let kconf = kpts[KPT_STEP * i + 2];
if kconf < self.kconf {
kpts_.push(Point2::default());
} else {
kpts_.push(Point2::new_with_conf(
kx.max(0.0f32).min(width_original),
ky.max(0.0f32).min(height_original),
kconf,
));
}
}
Some(kpts_)
} else {
None
}
};
// data merged
data.push((y_bbox, y_kpts, coefs));
}
// nms
non_max_suppression(&mut data, self.iou);
// decode
let mut y_bboxes: Vec<Bbox> = Vec::new();
let mut y_kpts: Vec<Vec<Point2>> = Vec::new();
let mut y_masks: Vec<Vec<u8>> = Vec::new();
for elem in data.into_iter() {
if let Some(kpts) = elem.1 {
y_kpts.push(kpts)
}
// decode masks
if let Some(coefs) = elem.2 {
let proto = protos.unwrap().slice(s![idx, .., .., ..]);
let (nm, nh, nw) = proto.dim();
// coefs * proto -> mask
let coefs = Array::from_shape_vec((1, nm), coefs)?; // (1, nm)
let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw)
let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, 1)
// build image from ndarray
let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
Some(image) => image,
None => panic!("can not create image from ndarray"),
};
let mut mask_im = image::DynamicImage::from(mask_im); // -> dyn
// rescale masks
let (_, w_mask, h_mask) =
self.scale_wh(width_original, height_original, nw as f32, nh as f32);
let mask_cropped = mask_im.crop(0, 0, w_mask as u32, h_mask as u32);
let mask_original = mask_cropped.resize_exact(
// resize_to_fill
width_original as u32,
height_original as u32,
match self.task() {
YOLOTask::Segment => image::imageops::FilterType::CatmullRom,
_ => image::imageops::FilterType::Triangle,
},
);
// crop-mask with bbox
let mut mask_original_cropped = mask_original.into_luma8();
for y in 0..height_original as usize {
for x in 0..width_original as usize {
if x < elem.0.xmin() as usize
|| x > elem.0.xmax() as usize
|| y < elem.0.ymin() as usize
|| y > elem.0.ymax() as usize
{
mask_original_cropped.put_pixel(
x as u32,
y as u32,
image::Luma([0u8]),
);
}
}
}
y_masks.push(mask_original_cropped.into_raw());
}
y_bboxes.push(elem.0);
}
// save each result
let y = YOLOResult {
probs: None,
bboxes: if !y_bboxes.is_empty() {
Some(y_bboxes)
} else {
None
},
keypoints: if !y_kpts.is_empty() {
Some(y_kpts)
} else {
None
},
masks: if !y_masks.is_empty() {
Some(y_masks)
} else {
None
},
};
ys.push(y);
}
Ok(ys)
}
}
pub fn plot_and_save(
&self,
ys: &[YOLOResult],
xs0: &[DynamicImage],
skeletons: Option<&[(usize, usize)]>,
) {
// check font then load
let font = check_font("Arial.ttf");
for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() {
let mut img = img0.to_rgb8();
// draw for classifier
if let Some(probs) = y.probs() {
for (i, k) in probs.topk(5).iter().enumerate() {
let legend = format!("{} {:.2}%", self.names[k.0], k.1);
let scale = 32;
let legend_size = img.width().max(img.height()) / scale;
let x = img.width() / 20;
let y = img.height() / 20 + i as u32 * legend_size;
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([0, 255, 0]),
x as i32,
y as i32,
rusttype::Scale::uniform(legend_size as f32 - 1.),
&font,
&legend,
);
}
}
// draw bboxes & keypoints
if let Some(bboxes) = y.bboxes() {
for (_idx, bbox) in bboxes.iter().enumerate() {
// rect
imageproc::drawing::draw_hollow_rect_mut(
&mut img,
imageproc::rect::Rect::at(bbox.xmin() as i32, bbox.ymin() as i32)
.of_size(bbox.width() as u32, bbox.height() as u32),
image::Rgb(self.color_palette[bbox.id()].into()),
);
// text
let legend = format!("{} {:.2}%", self.names[bbox.id()], bbox.confidence());
let scale = 40;
let legend_size = img.width().max(img.height()) / scale;
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb(self.color_palette[bbox.id()].into()),
bbox.xmin() as i32,
(bbox.ymin() - legend_size as f32) as i32,
rusttype::Scale::uniform(legend_size as f32 - 1.),
&font,
&legend,
);
}
}
// draw kpts
if let Some(keypoints) = y.keypoints() {
for kpts in keypoints.iter() {
for kpt in kpts.iter() {
// filter
if kpt.confidence() < self.kconf {
continue;
}
// draw point
imageproc::drawing::draw_filled_circle_mut(
&mut img,
(kpt.x() as i32, kpt.y() as i32),
2,
image::Rgb([0, 255, 0]),
);
}
// draw skeleton if has
if let Some(skeletons) = skeletons {
for &(idx1, idx2) in skeletons.iter() {
let kpt1 = &kpts[idx1];
let kpt2 = &kpts[idx2];
if kpt1.confidence() < self.kconf || kpt2.confidence() < self.kconf {
continue;
}
imageproc::drawing::draw_line_segment_mut(
&mut img,
(kpt1.x(), kpt1.y()),
(kpt2.x(), kpt2.y()),
image::Rgb([233, 14, 57]),
);
}
}
}
}
// draw mask
if let Some(masks) = y.masks() {
for (mask, _bbox) in masks.iter().zip(y.bboxes().unwrap().iter()) {
let mask_nd: ImageBuffer<image::Luma<_>, Vec<u8>> =
match ImageBuffer::from_vec(img.width(), img.height(), mask.to_vec()) {
Some(image) => image,
None => panic!("can not crate image from ndarray"),
};
for _x in 0..img.width() {
for _y in 0..img.height() {
let mask_p = imageproc::drawing::Canvas::get_pixel(&mask_nd, _x, _y);
if mask_p.0[0] > 0 {
let mut img_p = imageproc::drawing::Canvas::get_pixel(&img, _x, _y);
// img_p.0[2] = self.color_palette[bbox.id()].2 / 2;
// img_p.0[1] = self.color_palette[bbox.id()].1 / 2;
// img_p.0[0] = self.color_palette[bbox.id()].0 / 2;
img_p.0[2] /= 2;
img_p.0[1] = 255 - (255 - img_p.0[2]) / 2;
img_p.0[0] /= 2;
imageproc::drawing::Canvas::draw_pixel(&mut img, _x, _y, img_p)
}
}
}
}
}
// mkdir and save
let mut runs = PathBuf::from("runs");
if !runs.exists() {
std::fs::create_dir_all(&runs).unwrap();
}
runs.push(gen_time_string("-"));
let saveout = format!("{}.jpg", runs.to_str().unwrap());
let _ = img.save(saveout);
}
}
pub fn summary(&self) {
println!(
"\nSummary:\n\
> Task: {:?}{}\n\
> EP: {:?} {}\n\
> Dtype: {:?}\n\
> Batch: {} ({}), Height: {} ({}), Width: {} ({})\n\
> nc: {}, nk: {}, nm: {}, conf: {}, kconf: {}, iou: {}\n\
",
self.task(),
match self.engine.author().zip(self.engine.version()) {
Some((author, ver)) => format!(" ({} {})", author, ver),
None => String::from(""),
},
self.engine.ep(),
if let OrtEP::Cpu = self.engine.ep() {
""
} else {
"(May still fall back to CPU)"
},
self.engine.dtype(),
self.batch(),
if self.engine.is_batch_dynamic() {
"Dynamic"
} else {
"Const"
},
self.height(),
if self.engine.is_height_dynamic() {
"Dynamic"
} else {
"Const"
},
self.width(),
if self.engine.is_width_dynamic() {
"Dynamic"
} else {
"Const"
},
self.nc(),
self.nk(),
self.nm(),
self.conf,
self.kconf,
self.iou,
);
}
pub fn engine(&self) -> &OrtBackend {
&self.engine
}
pub fn conf(&self) -> f32 {
self.conf
}
pub fn set_conf(&mut self, val: f32) {
self.conf = val;
}
pub fn conf_mut(&mut self) -> &mut f32 {
&mut self.conf
}
pub fn kconf(&self) -> f32 {
self.kconf
}
pub fn iou(&self) -> f32 {
self.iou
}
pub fn task(&self) -> &YOLOTask {
&self.task
}
pub fn batch(&self) -> u32 {
self.batch
}
pub fn width(&self) -> u32 {
self.width
}
pub fn height(&self) -> u32 {
self.height
}
pub fn nc(&self) -> u32 {
self.nc
}
pub fn nk(&self) -> u32 {
self.nk
}
pub fn nm(&self) -> u32 {
self.nm
}
pub fn names(&self) -> &Vec<String> {
&self.names
}
}

@@ -0,0 +1,534 @@
use anyhow::Result;
use clap::ValueEnum;
use half::f16;
use ndarray::{Array, CowArray, IxDyn};
use ort::execution_providers::{CUDAExecutionProviderOptions, TensorRTExecutionProviderOptions};
use ort::tensor::TensorElementDataType;
use ort::{Environment, ExecutionProvider, Session, SessionBuilder, Value};
use regex::Regex;
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
pub enum YOLOTask {
// YOLO tasks
Classify,
Detect,
Pose,
Segment,
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum OrtEP {
// ONNXRuntime execution provider
Cpu,
Cuda(u32),
Trt(u32),
}
#[derive(Debug)]
pub struct Batch {
pub opt: u32,
pub min: u32,
pub max: u32,
}
impl Default for Batch {
fn default() -> Self {
Self {
opt: 1,
min: 1,
max: 1,
}
}
}
#[derive(Debug, Default)]
pub struct OrtInputs {
// ONNX model inputs attrs
pub shapes: Vec<Vec<i32>>,
pub dtypes: Vec<TensorElementDataType>,
pub names: Vec<String>,
pub sizes: Vec<Vec<u32>>,
}
impl OrtInputs {
pub fn new(session: &Session) -> Self {
let mut shapes = Vec::new();
let mut dtypes = Vec::new();
let mut names = Vec::new();
for i in session.inputs.iter() {
let shape: Vec<i32> = i
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
dtypes.push(i.input_type);
names.push(i.name.clone());
}
Self {
shapes,
dtypes,
names,
..Default::default()
}
}
}
#[derive(Debug)]
pub struct OrtConfig {
// ORT config
pub f: String,
pub task: Option<YOLOTask>,
pub ep: OrtEP,
pub trt_fp16: bool,
pub batch: Batch,
pub image_size: (Option<u32>, Option<u32>),
}
#[derive(Debug)]
pub struct OrtBackend {
// ORT engine
session: Session,
task: YOLOTask,
ep: OrtEP,
batch: Batch,
inputs: OrtInputs,
}
impl OrtBackend {
pub fn build(args: OrtConfig) -> Result<Self> {
// build env & session
let env = Environment::builder()
.with_name("YOLOv8")
.with_log_level(ort::LoggingLevel::Verbose)
.build()?
.into_arc();
let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
// get inputs
let mut inputs = OrtInputs::new(&session);
// batch size
let mut batch = args.batch;
let batch = if inputs.shapes[0][0] == -1 {
batch
} else {
assert_eq!(
inputs.shapes[0][0] as u32, batch.opt,
"Expected batch size: {}, got {}. Try using `--batch {}`.",
inputs.shapes[0][0] as u32, batch.opt, inputs.shapes[0][0] as u32
);
batch.opt = inputs.shapes[0][0] as u32;
batch
};
// input size: height and width
let height = if inputs.shapes[0][2] == -1 {
match args.image_size.0 {
Some(height) => height,
None => panic!("Failed to get model height. Make it explicit with `--height`"),
}
} else {
inputs.shapes[0][2] as u32
};
let width = if inputs.shapes[0][3] == -1 {
match args.image_size.1 {
Some(width) => width,
None => panic!("Failed to get model width. Make it explicit with `--width`"),
}
} else {
inputs.shapes[0][3] as u32
};
inputs.sizes.push(vec![height, width]);
// build provider
let (ep, provider) = match args.ep {
OrtEP::Cuda(device_id) => Self::set_ep_cuda(device_id),
OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
_ => (OrtEP::Cpu, ExecutionProvider::CPU(Default::default())),
};
// build session again with the new provider
let session = SessionBuilder::new(&env)?
// .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
.with_execution_providers([provider])?
.with_model_from_file(args.f)?;
// task: using given one or guessing
let task = match args.task {
Some(task) => task,
None => match session.metadata() {
Err(_) => panic!("No metadata found. Try making it explicit by `--task`"),
Ok(metadata) => match metadata.custom("task") {
Err(_) => panic!("Can not get custom value. Try making it explicit by `--task`"),
Ok(value) => match value {
None => panic!("No correspoing value of `task` found in metadata. Make it explicit by `--task`"),
Some(task) => match task.as_str() {
"classify" => YOLOTask::Classify,
"detect" => YOLOTask::Detect,
"pose" => YOLOTask::Pose,
"segment" => YOLOTask::Segment,
x => todo!("{:?} is not supported for now!", x),
},
},
},
},
};
Ok(Self {
session,
task,
ep,
batch,
inputs,
})
}
pub fn fetch_inputs_from_session(
session: &Session,
) -> (Vec<Vec<i32>>, Vec<TensorElementDataType>, Vec<String>) {
// get inputs attrs from ONNX model
let mut shapes = Vec::new();
let mut dtypes = Vec::new();
let mut names = Vec::new();
for i in session.inputs.iter() {
let shape: Vec<i32> = i
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
dtypes.push(i.input_type);
names.push(i.name.clone());
}
(shapes, dtypes, names)
}
pub fn set_ep_cuda(device_id: u32) -> (OrtEP, ExecutionProvider) {
// set CUDA
if ExecutionProvider::CUDA(Default::default()).is_available() {
(
OrtEP::Cuda(device_id),
ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
device_id,
..Default::default()
}),
)
} else {
println!("> CUDA is not available! Using CPU.");
(OrtEP::Cpu, ExecutionProvider::CPU(Default::default()))
}
}
pub fn set_ep_trt(
device_id: u32,
fp16: bool,
batch: &Batch,
inputs: &OrtInputs,
) -> (OrtEP, ExecutionProvider) {
// set TensorRT
if ExecutionProvider::TensorRT(Default::default()).is_available() {
let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
// dtype match checking
if inputs.dtypes[0] == TensorElementDataType::Float16 && !fp16 {
panic!(
"Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`",
inputs.dtypes[0]
);
}
// dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
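// e.g. for a single input named `images`: "images:1x3x640x640"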
let mut opt_string = String::new();
let mut min_string = String::new();
let mut max_string = String::new();
for name in inputs.names.iter() {
let s_opt = format!("{}:{}x3x{}x{},", name, batch.opt, height, width);
let s_min = format!("{}:{}x3x{}x{},", name, batch.min, height, width);
let s_max = format!("{}:{}x3x{}x{},", name, batch.max, height, width);
opt_string.push_str(s_opt.as_str());
min_string.push_str(s_min.as_str());
max_string.push_str(s_max.as_str());
}
let _ = opt_string.pop();
let _ = min_string.pop();
let _ = max_string.pop();
(
OrtEP::Trt(device_id),
ExecutionProvider::TensorRT(TensorRTExecutionProviderOptions {
device_id,
fp16_enable: fp16,
timing_cache_enable: true,
profile_min_shapes: min_string,
profile_max_shapes: max_string,
profile_opt_shapes: opt_string,
..Default::default()
}),
)
} else {
println!("> TensorRT is not available! Try using CUDA...");
Self::set_ep_cuda(device_id)
}
}
pub fn fetch_from_metadata(&self, key: &str) -> Option<String> {
// fetch value from onnx model file by key
match self.session.metadata() {
Err(_) => None,
Ok(metadata) => match metadata.custom(key) {
Err(_) => None,
Ok(value) => value,
},
}
}
pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
// ORT inference
match self.dtype() {
TensorElementDataType::Float16 => self.run_fp16(xs, profile),
TensorElementDataType::Float32 => self.run_fp32(xs, profile),
_ => todo!(),
}
}
pub fn run_fp16(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
// f32->f16
let t = std::time::Instant::now();
let xs = xs.mapv(f16::from_f32);
if profile {
println!("[ORT f32->f16]: {:?}", t.elapsed());
}
// h2d
let t = std::time::Instant::now();
let xs = CowArray::from(xs);
let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
if profile {
println!("[ORT H2D]: {:?}", t.elapsed());
}
// run
let t = std::time::Instant::now();
let ys = self.session.run(xs)?;
if profile {
println!("[ORT Inference]: {:?}", t.elapsed());
}
// d2h
Ok(ys
.iter()
.map(|x| {
// d2h
let t = std::time::Instant::now();
let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
if profile {
println!("[ORT D2H]: {:?}", t.elapsed());
}
// f16->f32
let t_ = std::time::Instant::now();
let x = x.mapv(f16::to_f32);
if profile {
println!("[ORT f16->f32]: {:?}", t_.elapsed());
}
x
})
.collect::<Vec<Array<_, _>>>())
}
pub fn run_fp32(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
// h2d
let t = std::time::Instant::now();
let xs = CowArray::from(xs);
let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
if profile {
println!("[ORT H2D]: {:?}", t.elapsed());
}
// run
let t = std::time::Instant::now();
let ys = self.session.run(xs)?;
if profile {
println!("[ORT Inference]: {:?}", t.elapsed());
}
// d2h
Ok(ys
.iter()
.map(|x| {
let t = std::time::Instant::now();
let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
if profile {
println!("[ORT D2H]: {:?}", t.elapsed());
}
x
})
.collect::<Vec<Array<_, _>>>())
}
pub fn output_shapes(&self) -> Vec<Vec<i32>> {
let mut shapes = Vec::new();
for o in &self.session.outputs {
let shape: Vec<_> = o
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
}
shapes
}
pub fn output_dtypes(&self) -> Vec<TensorElementDataType> {
let mut dtypes = Vec::new();
self.session
.outputs
.iter()
.for_each(|x| dtypes.push(x.output_type));
dtypes
}
pub fn input_shapes(&self) -> &Vec<Vec<i32>> {
&self.inputs.shapes
}
pub fn input_names(&self) -> &Vec<String> {
&self.inputs.names
}
pub fn input_dtypes(&self) -> &Vec<TensorElementDataType> {
&self.inputs.dtypes
}
pub fn dtype(&self) -> TensorElementDataType {
self.input_dtypes()[0]
}
pub fn height(&self) -> u32 {
self.inputs.sizes[0][0]
}
pub fn width(&self) -> u32 {
self.inputs.sizes[0][1]
}
pub fn is_height_dynamic(&self) -> bool {
self.input_shapes()[0][2] == -1
}
pub fn is_width_dynamic(&self) -> bool {
self.input_shapes()[0][3] == -1
}
pub fn batch(&self) -> u32 {
self.batch.opt
}
pub fn is_batch_dynamic(&self) -> bool {
self.input_shapes()[0][0] == -1
}
pub fn ep(&self) -> &OrtEP {
&self.ep
}
pub fn task(&self) -> YOLOTask {
self.task.clone()
}
pub fn names(&self) -> Option<Vec<String>> {
// class names, metadata parsing
// String format: `{0: 'person', 1: 'bicycle', 2: 'sports ball', ..., 27: "yellow_lady's_slipper"}`
match self.fetch_from_metadata("names") {
Some(names) => {
let re = Regex::new(r#"(['"])([-()\w '"]+)(['"])"#).unwrap();
let mut names_ = vec![];
for (_, [_, name, _]) in re.captures_iter(&names).map(|x| x.extract()) {
names_.push(name.to_string());
}
Some(names_)
}
None => None,
}
}
pub fn nk(&self) -> Option<u32> {
// num_keypoints, parsed from the `kpt_shape` metadata string in the onnx model, e.g. `[17, 3]`
match self.fetch_from_metadata("kpt_shape") {
None => None,
Some(kpt_string) => {
let re = Regex::new(r"([0-9]+), ([0-9]+)").unwrap();
let caps = re.captures(&kpt_string).unwrap();
Some(caps.get(1).unwrap().as_str().parse::<u32>().unwrap())
}
}
}
pub fn nc(&self) -> Option<u32> {
// num_classes
match self.names() {
// by names
Some(names) => Some(names.len() as u32),
None => match self.task() {
// by task calculation
YOLOTask::Classify => Some(self.output_shapes()[0][1] as u32),
YOLOTask::Detect => {
if self.output_shapes()[0][1] == -1 {
None
} else {
// cx, cy, w, h + classes
Some(self.output_shapes()[0][1] as u32 - 4)
}
}
YOLOTask::Pose => {
match self.nk() {
None => None,
Some(nk) => {
if self.output_shapes()[0][1] == -1 {
None
} else {
// cx, cy, w, h + classes + 3*nk keypoint values
Some(self.output_shapes()[0][1] as u32 - 4 - 3 * nk)
}
}
}
}
YOLOTask::Segment => {
if self.output_shapes()[0][1] == -1 {
None
} else {
// cx, cy, w, h + classes + nm mask coefficients
Some((self.output_shapes()[0][1] - self.output_shapes()[1][1]) as u32 - 4)
}
}
},
}
}
pub fn nm(&self) -> Option<u32> {
// num_masks
match self.task() {
YOLOTask::Segment => Some(self.output_shapes()[1][1] as u32),
_ => None,
}
}
pub fn na(&self) -> Option<u32> {
// num_anchors
match self.task() {
YOLOTask::Segment | YOLOTask::Detect | YOLOTask::Pose => {
if self.output_shapes()[0][2] == -1 {
None
} else {
Some(self.output_shapes()[0][2] as u32)
}
}
_ => None,
}
}
pub fn author(&self) -> Option<String> {
self.fetch_from_metadata("author")
}
pub fn version(&self) -> Option<String> {
self.fetch_from_metadata("version")
}
}

@@ -0,0 +1,235 @@
use ndarray::{Array, Axis, IxDyn};
#[derive(Clone, PartialEq, Default)]
pub struct YOLOResult {
// YOLO tasks results of an image
pub probs: Option<Embedding>,
pub bboxes: Option<Vec<Bbox>>,
pub keypoints: Option<Vec<Vec<Point2>>>,
pub masks: Option<Vec<Vec<u8>>>,
}
impl std::fmt::Debug for YOLOResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("YOLOResult")
.field(
"Probs(top5)",
&format_args!("{:?}", self.probs().map(|probs| probs.topk(5))),
)
.field("Bboxes", &self.bboxes)
.field("Keypoints", &self.keypoints)
.field(
"Masks",
&format_args!("{:?}", self.masks().map(|masks| masks.len())),
)
.finish()
}
}
impl YOLOResult {
pub fn new(
probs: Option<Embedding>,
bboxes: Option<Vec<Bbox>>,
keypoints: Option<Vec<Vec<Point2>>>,
masks: Option<Vec<Vec<u8>>>,
) -> Self {
Self {
probs,
bboxes,
keypoints,
masks,
}
}
pub fn probs(&self) -> Option<&Embedding> {
self.probs.as_ref()
}
pub fn keypoints(&self) -> Option<&Vec<Vec<Point2>>> {
self.keypoints.as_ref()
}
pub fn masks(&self) -> Option<&Vec<Vec<u8>>> {
self.masks.as_ref()
}
pub fn bboxes(&self) -> Option<&Vec<Bbox>> {
self.bboxes.as_ref()
}
pub fn bboxes_mut(&mut self) -> Option<&mut Vec<Bbox>> {
self.bboxes.as_mut()
}
}
#[derive(Debug, PartialEq, Clone, Default)]
pub struct Point2 {
// A point2d with x, y, conf
x: f32,
y: f32,
confidence: f32,
}
impl Point2 {
pub fn new_with_conf(x: f32, y: f32, confidence: f32) -> Self {
Self { x, y, confidence }
}
pub fn new(x: f32, y: f32) -> Self {
Self {
x,
y,
..Default::default()
}
}
pub fn x(&self) -> f32 {
self.x
}
pub fn y(&self) -> f32 {
self.y
}
pub fn confidence(&self) -> f32 {
self.confidence
}
}
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Embedding {
// A float32 n-dim tensor
data: Array<f32, IxDyn>,
}
impl Embedding {
pub fn new(data: Array<f32, IxDyn>) -> Self {
Self { data }
}
pub fn data(&self) -> &Array<f32, IxDyn> {
&self.data
}
pub fn topk(&self, k: usize) -> Vec<(usize, f32)> {
let mut probs = self
.data
.iter()
.enumerate()
.map(|(a, b)| (a, *b))
.collect::<Vec<_>>();
probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
let mut topk = Vec::new();
for &(id, confidence) in probs.iter().take(k) {
topk.push((id, confidence));
}
topk
}
pub fn norm(&self) -> Array<f32, IxDyn> {
let std_ = self.data.mapv(|x| x * x).sum_axis(Axis(0)).mapv(f32::sqrt);
self.data.clone() / std_
}
pub fn top1(&self) -> (usize, f32) {
self.topk(1)[0]
}
}
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Bbox {
// a bounding box around an object
xmin: f32,
ymin: f32,
width: f32,
height: f32,
id: usize,
confidence: f32,
}
impl Bbox {
pub fn new_from_xywh(xmin: f32, ymin: f32, width: f32, height: f32) -> Self {
Self {
xmin,
ymin,
width,
height,
..Default::default()
}
}
pub fn new(xmin: f32, ymin: f32, width: f32, height: f32, id: usize, confidence: f32) -> Self {
Self {
xmin,
ymin,
width,
height,
id,
confidence,
}
}
pub fn width(&self) -> f32 {
self.width
}
pub fn height(&self) -> f32 {
self.height
}
pub fn xmin(&self) -> f32 {
self.xmin
}
pub fn ymin(&self) -> f32 {
self.ymin
}
pub fn xmax(&self) -> f32 {
self.xmin + self.width
}
pub fn ymax(&self) -> f32 {
self.ymin + self.height
}
pub fn tl(&self) -> Point2 {
Point2::new(self.xmin, self.ymin)
}
pub fn br(&self) -> Point2 {
Point2::new(self.xmax(), self.ymax())
}
pub fn cxcy(&self) -> Point2 {
Point2::new(self.xmin + self.width / 2., self.ymin + self.height / 2.)
}
pub fn id(&self) -> usize {
self.id
}
pub fn confidence(&self) -> f32 {
self.confidence
}
pub fn area(&self) -> f32 {
self.width * self.height
}
pub fn intersection_area(&self, another: &Bbox) -> f32 {
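// Note: the overlap uses an inclusive (+1) pixel convention, while `area()` does not,
// so IoU can be slightly biased for very small boxes.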
let l = self.xmin.max(another.xmin);
let r = (self.xmin + self.width).min(another.xmin + another.width);
let t = self.ymin.max(another.ymin);
let b = (self.ymin + self.height).min(another.ymin + another.height);
(r - l + 1.).max(0.) * (b - t + 1.).max(0.)
}
pub fn union(&self, another: &Bbox) -> f32 {
self.area() + another.area() - self.intersection_area(another)
}
pub fn iou(&self, another: &Bbox) -> f32 {
self.intersection_area(another) / self.union(another)
}
}