diff --git a/docs/en/guides/triton-inference-server.md b/docs/en/guides/triton-inference-server.md
index 7395ccef11..09f7516b11 100644
--- a/docs/en/guides/triton-inference-server.md
+++ b/docs/en/guides/triton-inference-server.md
@@ -80,6 +80,28 @@ The Triton Model Repository is a storage location where Triton can access and lo
 # Create config file
 (triton_model_path / "config.pbtxt").touch()
+
+# (Optional) Enable TensorRT for GPU inference
+# First run will be slow due to TensorRT engine conversion
+import json
+
+data = {
+    "optimization": {
+        "execution_accelerators": {
+            "gpu_execution_accelerator": [
+                {
+                    "name": "tensorrt",
+                    "parameters": {
+                        "precision_mode": "FP16",
+                        "max_workspace_size_bytes": "3221225472",
+                        "trt_engine_cache_enable": "1",
+                    },
+                }
+            ]
+        }
+    }
+}
+
+with open(triton_model_path / "config.pbtxt", "w") as f:
+    json.dump(data, f, indent=4)
 ```

 ## Running Triton Inference Server
@@ -94,7 +116,7 @@ import time
 from tritonclient.http import InferenceServerClient

 # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
-tag = "nvcr.io/nvidia/tritonserver:23.09-py3"  # 6.4 GB
+tag = "nvcr.io/nvidia/tritonserver:24.09-py3"  # 8.57 GB

 # Pull the image
 subprocess.call(f"docker pull {tag}", shell=True)
@@ -187,7 +209,7 @@ Setting up [Ultralytics YOLO11](https://docs.ultralytics.com/models/yolov8/) wit
         from tritonclient.http import InferenceServerClient

         # Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
-        tag = "nvcr.io/nvidia/tritonserver:23.09-py3"
+        tag = "nvcr.io/nvidia/tritonserver:24.09-py3"

         subprocess.call(f"docker pull {tag}", shell=True)
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml b/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml
index 8ac747e7e3..fcf1fb7974 100644
--- a/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml
+++ b/examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml
@@ -9,11 +9,11 @@ edition = "2021"

 [dependencies]
 clap = { version = "4.2.4", features = ["derive"] }
-image = { version = "0.24.7", default-features = false, features = ["jpeg", "png", "webp-encoder"] }
-imageproc = { version = "0.23.0", default-features = false }
-ndarray = { version = "0.15.6" }
-ort = { version = "1.16.3", default-features = false, features = ["load-dynamic", "copy-dylibs", "half"] }
-rusttype = { version = "0.9", default-features = false }
+image = { version = "0.25.2" }
+imageproc = { version = "0.25.0" }
+ndarray = { version = "0.16" }
+ort = { version = "2.0.0-rc.5", features = ["cuda", "tensorrt"] }
+rusttype = { version = "0.9.3" }
 anyhow = { version = "1.0.75" }
 regex = { version = "1.5.4" }
 rand = { version = "0.8.5" }
@@ -21,3 +21,4 @@ chrono = { version = "0.4.30" }
 half = { version = "2.3.1" }
 dirs = { version = "5.0.1" }
 ureq = { version = "2.9.1" }
+ab_glyph = "0.2.29"
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/README.md b/examples/YOLOv8-ONNXRuntime-Rust/README.md
index 48a3017ce8..9121c7dac7 100644
--- a/examples/YOLOv8-ONNXRuntime-Rust/README.md
+++ b/examples/YOLOv8-ONNXRuntime-Rust/README.md
@@ -5,7 +5,7 @@ This repository provides a Rust demo for performing YOLOv8 tasks like `Classific
 ## Recently Updated

 - Add YOLOv8-OBB demo
-- Update ONNXRuntime to 1.17.x
+- Update ONNXRuntime to 1.19.x

 Newly updated YOLOv8 example code is located in this repository (https://github.com/jamjamjon/usls/tree/main/examples/yolo)
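A note on the Triton guide hunks above: the client code in the guide assumes the server inside the container is already serving the model, and with TensorRT enabled in `config.pbtxt` the first load can take noticeably longer while the engine is built. A minimal readiness poll with the same `tritonclient` package is sketched below; the `localhost:8000` endpoint and the `yolo` model name are assumptions that must match your `docker run` flags and repository layout.

```python
import time

from tritonclient.http import InferenceServerClient

# Assumed HTTP endpoint and model name; adjust to your container flags and model repository
client = InferenceServerClient(url="localhost:8000", verbose=False, ssl=False)

for _ in range(60):
    try:
        if client.is_server_ready() and client.is_model_ready("yolo"):
            break  # server and model are ready to accept inference requests
    except Exception:
        pass  # HTTP endpoint not accepting connections yet
    time.sleep(1)
```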
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs
index 2ba0dd49ec..b5bc05a585 100644
--- a/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs
@@ -15,7 +15,7 @@ pub struct Args {

     /// device id
     #[arg(long, default_value_t = 0)]
-    pub device_id: u32,
+    pub device_id: i32,

     /// using TensorRT EP
     #[arg(long)]
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs
index 1af7f7c5e1..849801ee47 100644
--- a/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs
@@ -117,3 +117,45 @@ pub fn check_font(font: &str) -> rusttype::Font<'static> {
     let buffer = std::fs::read(font_path).unwrap();
     rusttype::Font::try_from_vec(buffer).unwrap()
 }
+
+use ab_glyph::FontArc;
+
+pub fn load_font() -> FontArc {
+    use std::path::Path;
+    let font_path = Path::new("./font/Arial.ttf");
+    match font_path.try_exists() {
+        Ok(true) => {
+            let buffer = std::fs::read(font_path).unwrap();
+            FontArc::try_from_vec(buffer).unwrap()
+        }
+        Ok(false) => {
+            std::fs::create_dir_all("./font").unwrap();
+            println!("Downloading font...");
+            let source_url = "https://ultralytics.com/assets/Arial.ttf";
+            let resp = ureq::get(source_url)
+                .timeout(std::time::Duration::from_secs(500))
+                .call()
+                .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));
+
+            // read to buffer
+            let mut buffer = vec![];
+            let total_size = resp
+                .header("Content-Length")
+                .and_then(|s| s.parse::<u64>().ok())
+                .unwrap();
+            let _reader = resp
+                .into_reader()
+                .take(total_size)
+                .read_to_end(&mut buffer)
+                .unwrap();
+
+            // save font to disk, then load it
+            let mut fd = std::fs::File::create(font_path).unwrap();
+            fd.write_all(&buffer).unwrap();
+            println!("Font saved at: {:?}", font_path.display());
+            FontArc::try_from_vec(buffer).unwrap()
+        }
+        Err(e) => {
+            panic!("Failed to load font {}", e);
+        }
+    }
+}
\ No newline at end of file
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs
index 8dd1567990..fd3845ced0 100644
--- a/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/main.rs
@@ -6,7 +6,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     let args = Args::parse();

     // 1. load image
-    let x = image::io::Reader::open(&args.source)?
+    let x = image::ImageReader::open(&args.source)?
         .with_guessed_format()?
         .decode()?;
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs
index 1c0e5e494d..e0c35f6c26 100644
--- a/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/model.rs
@@ -1,5 +1,6 @@
 #![allow(clippy::type_complexity)]

+use ab_glyph::FontArc;
 use anyhow::Result;
 use image::{DynamicImage, GenericImageView, ImageBuffer};
 use ndarray::{s, Array, Axis, IxDyn};
@@ -7,7 +8,7 @@ use rand::{thread_rng, Rng};
 use std::path::PathBuf;

 use crate::{
-    check_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
+    load_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
     OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON,
 };
@@ -36,10 +37,10 @@ impl YOLOv8 {
         let ep = if config.trt {
             OrtEP::Trt(config.device_id)
         } else if config.cuda {
-            OrtEP::Cuda(config.device_id)
+            OrtEP::CUDA(config.device_id)
         } else {
-            OrtEP::Cpu
-        };
+            OrtEP::CPU
+        };

         // batch
         let batch = Batch {
@@ -330,12 +331,19 @@ impl YOLOv8 {
             // coefs * proto -> mask
             let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm)
-            let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw)
-            let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n)
+
+            let proto = proto.to_owned();
+            let proto = proto.to_shape((nm, nh * nw))?; // (nm, nh*nw)
+            let mask = coefs.dot(&proto); // (nh, nw, n)
+            let mask = mask.to_shape((nh, nw, 1))?;

             // build image from ndarray
             let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
-                match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
+                match ImageBuffer::from_raw(
+                    nw as u32,
+                    nh as u32,
+                    mask.to_owned().into_raw_vec_and_offset().0,
+                ) {
                     Some(image) => image,
                     None => panic!("can not create image from ndarray"),
                 };
@@ -410,7 +418,7 @@ impl YOLOv8 {
         skeletons: Option<&[(usize, usize)]>,
     ) {
         // check font then load
-        let font = check_font("Arial.ttf");
+        let font: FontArc = load_font();
         for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() {
             let mut img = img0.to_rgb8();

@@ -422,12 +430,13 @@ impl YOLOv8 {
                     let legend_size = img.width().max(img.height()) / scale;
                     let x = img.width() / 20;
                     let y = img.height() / 20 + i as u32 * legend_size;
+
                     imageproc::drawing::draw_text_mut(
                         &mut img,
                         image::Rgb([0, 255, 0]),
                         x as i32,
                         y as i32,
-                        rusttype::Scale::uniform(legend_size as f32 - 1.),
+                        legend_size as f32,
                         &font,
                         &legend,
                     );
@@ -454,7 +463,7 @@ impl YOLOv8 {
                         image::Rgb(self.color_palette[bbox.id()].into()),
                         bbox.xmin() as i32,
                         (bbox.ymin() - legend_size as f32) as i32,
-                        rusttype::Scale::uniform(legend_size as f32 - 1.),
+                        legend_size as f32,
                         &font,
                         &legend,
                     );
@@ -551,7 +560,7 @@ impl YOLOv8 {
                 None => String::from(""),
             },
             self.engine.ep(),
-            if let OrtEP::Cpu = self.engine.ep() {
+            if let OrtEP::CPU = self.engine.ep() {
                 ""
             } else {
                 "(May still fall back to CPU)"
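The mask hunk in `model.rs` above only ports the post-processing to the newer `ndarray` API (`to_shape`, `into_raw_vec_and_offset`); the underlying segmentation math is unchanged: one detection's `nm` mask coefficients are multiplied with the `nm x (nh*nw)` prototype matrix and the product is reshaped into an `nh x nw` mask. A small NumPy sketch of that step, with illustrative shapes only:

```python
import numpy as np

nm, nh, nw = 32, 160, 160  # illustrative prototype count and prototype resolution

coefs = np.random.rand(1, nm).astype(np.float32)  # (1, nm) coefficients for one detection
proto = np.random.rand(nm, nh, nw).astype(np.float32)  # (nm, nh, nw) prototype masks

mask = coefs @ proto.reshape(nm, nh * nw)  # (1, nh*nw) linear combination of prototypes
mask = mask.reshape(nh, nw)  # (nh, nw) raw mask, thresholded and cropped downstream
print(mask.shape)
```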
diff --git a/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs b/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs
index 857baaebae..d88208dead 100644
--- a/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs
+++ b/examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs
@@ -2,11 +2,13 @@ use anyhow::Result;
 use clap::ValueEnum;
 use half::f16;
 use ndarray::{Array, CowArray, IxDyn};
-use ort::execution_providers::{CUDAExecutionProviderOptions, TensorRTExecutionProviderOptions};
-use ort::tensor::TensorElementDataType;
-use ort::{Environment, ExecutionProvider, Session, SessionBuilder, Value};
+use ort::{
+    CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
+    TensorRTExecutionProvider,
+};
+use ort::{Session, SessionBuilder};
+use ort::{TensorElementType, ValueType};
 use regex::Regex;
-
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
 pub enum YOLOTask {
     // YOLO tasks
@@ -19,9 +21,9 @@ pub enum YOLOTask {
 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
 pub enum OrtEP {
     // ONNXRuntime execution provider
-    Cpu,
-    Cuda(u32),
-    Trt(u32),
+    CPU,
+    CUDA(i32),
+    Trt(i32),
 }

 #[derive(Debug)]
@@ -44,8 +46,9 @@ impl Default for Batch {
 #[derive(Debug, Default)]
 pub struct OrtInputs {
     // ONNX model inputs attrs
-    pub shapes: Vec<Vec<i32>>,
-    pub dtypes: Vec<TensorElementDataType>,
+    pub shapes: Vec<Vec<i64>>,
+    // pub dtypes: Vec<TensorElementDataType>,
+    pub dtypes: Vec<TensorElementType>,
     pub names: Vec<String>,
     pub sizes: Vec<Vec<u32>>,
 }
@@ -56,12 +59,19 @@ impl OrtInputs {
         let mut dtypes = Vec::new();
         let mut names = Vec::new();
         for i in session.inputs.iter() {
-            let shape: Vec<i32> = i
+            /* let shape: Vec<i32> = i
                 .dimensions()
                 .map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
                 .collect();
-            shapes.push(shape);
-            dtypes.push(i.input_type);
+            shapes.push(shape); */
+            if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
+                dtypes.push(ty.clone());
+                let shape = dimensions.clone();
+                shapes.push(shape);
+            } else {
+                panic!("Unsupported data format, {} - {}", file!(), line!());
+            }
+            //dtypes.push(i.input_type);
             names.push(i.name.clone());
         }
         Self {
@@ -97,12 +107,14 @@ pub struct OrtBackend {
 impl OrtBackend {
     pub fn build(args: OrtConfig) -> Result<Self> {
         // build env & session
-        let env = Environment::builder()
-            .with_name("YOLOv8")
-            .with_log_level(ort::LoggingLevel::Verbose)
-            .build()?
-            .into_arc();
-        let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
+        // in ort 2.x the explicit Environment is removed
+        /* let env = ort::EnvironmentBuilder
+            ::with_name("YOLOv8")
+            .build()?
+            .into_arc(); */
+        let sessionbuilder = SessionBuilder::new()?;
+        let session = sessionbuilder.commit_from_file(&args.f)?;
+        //let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;

         // get inputs
         let mut inputs = OrtInputs::new(&session);
@@ -142,16 +154,19 @@ impl OrtBackend {

         // build provider
         let (ep, provider) = match args.ep {
-            OrtEP::Cuda(device_id) => Self::set_ep_cuda(device_id),
+            OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id),
             OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
-            _ => (OrtEP::Cpu, ExecutionProvider::CPU(Default::default())),
+            _ => (
+                OrtEP::CPU,
+                ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
+            ),
         };

         // build session again with the new provider
-        let session = SessionBuilder::new(&env)?
+        let session = SessionBuilder::new()?
             // .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
             .with_execution_providers([provider])?
-            .with_model_from_file(args.f)?;
+            .commit_from_file(args.f)?;

         // task: using given one or guessing
         let task = match args.task {
@@ -185,57 +200,58 @@ impl OrtBackend {

     pub fn fetch_inputs_from_session(
         session: &Session,
-    ) -> (Vec<Vec<i32>>, Vec<TensorElementDataType>, Vec<String>) {
+    ) -> (Vec<Vec<i64>>, Vec<TensorElementType>, Vec<String>) {
         // get inputs attrs from ONNX model
         let mut shapes = Vec::new();
         let mut dtypes = Vec::new();
         let mut names = Vec::new();
         for i in session.inputs.iter() {
-            let shape: Vec<i32> = i
-                .dimensions()
-                .map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
-                .collect();
-            shapes.push(shape);
-            dtypes.push(i.input_type);
+            if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
+                dtypes.push(ty.clone());
+                let shape = dimensions.clone();
+                shapes.push(shape);
+            } else {
+                panic!("Unsupported data format, {} - {}", file!(), line!());
+            }
             names.push(i.name.clone());
         }
         (shapes, dtypes, names)
     }

-    pub fn set_ep_cuda(device_id: u32) -> (OrtEP, ExecutionProvider) {
-        // set CUDA
-        if ExecutionProvider::CUDA(Default::default()).is_available() {
+    pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) {
+        let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id);
+        if let Ok(true) = cuda_provider.is_available() {
             (
-                OrtEP::Cuda(device_id),
-                ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
-                    device_id,
-                    ..Default::default()
-                }),
+                OrtEP::CUDA(device_id),
+                ExecutionProviderDispatch::from(cuda_provider),
             )
         } else {
             println!("> CUDA is not available! Using CPU.");
-            (OrtEP::Cpu, ExecutionProvider::CPU(Default::default()))
+            (
+                OrtEP::CPU,
+                ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
+            )
         }
     }

     pub fn set_ep_trt(
-        device_id: u32,
+        device_id: i32,
         fp16: bool,
         batch: &Batch,
         inputs: &OrtInputs,
-    ) -> (OrtEP, ExecutionProvider) {
+    ) -> (OrtEP, ExecutionProviderDispatch) {
         // set TensorRT
-        if ExecutionProvider::TensorRT(Default::default()).is_available() {
-            let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
+        let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id);

-            // dtype match checking
-            if inputs.dtypes[0] == TensorElementDataType::Float16 && !fp16 {
+        if let Ok(true) = trt_provider.is_available() {
+            let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
+            if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 {
                 panic!(
                     "Dtype mismatch! Expected: Float32, got: {:?}. You should use `--fp16`",
                     inputs.dtypes[0]
                 );
             }
-
             // dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
             let mut opt_string = String::new();
             let mut min_string = String::new();
             let mut max_string = String::new();
@@ -251,17 +267,16 @@ impl OrtBackend {
             let _ = opt_string.pop();
             let _ = min_string.pop();
             let _ = max_string.pop();
+
+            let trt_provider = trt_provider
+                .with_profile_opt_shapes(opt_string)
+                .with_profile_min_shapes(min_string)
+                .with_profile_max_shapes(max_string)
+                .with_fp16(fp16)
+                .with_timing_cache(true);
             (
                 OrtEP::Trt(device_id),
-                ExecutionProvider::TensorRT(TensorRTExecutionProviderOptions {
-                    device_id,
-                    fp16_enable: fp16,
-                    timing_cache_enable: true,
-                    profile_min_shapes: min_string,
-                    profile_max_shapes: max_string,
-                    profile_opt_shapes: opt_string,
-                    ..Default::default()
-                }),
+                ExecutionProviderDispatch::from(trt_provider),
             )
         } else {
             println!("> TensorRT is not available! Try using CUDA...");
             Self::set_ep_cuda(device_id)
         }
     }
@@ -283,8 +298,8 @@ impl OrtBackend {
     pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
         // ORT inference
         match self.dtype() {
-            TensorElementDataType::Float16 => self.run_fp16(xs, profile),
-            TensorElementDataType::Float32 => self.run_fp32(xs, profile),
+            TensorElementType::Float16 => self.run_fp16(xs, profile),
+            TensorElementType::Float32 => self.run_fp32(xs, profile),
             _ => todo!(),
         }
     }
@@ -300,14 +315,13 @@ impl OrtBackend {
         // h2d
         let t = std::time::Instant::now();
         let xs = CowArray::from(xs);
-        let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
         if profile {
             println!("[ORT H2D]: {:?}", t.elapsed());
         }

         // run
         let t = std::time::Instant::now();
-        let ys = self.session.run(xs)?;
+        let ys = self.session.run(ort::inputs![xs.view()]?)?;
         if profile {
             println!("[ORT Inference]: {:?}", t.elapsed());
         }
@@ -315,21 +329,22 @@ impl OrtBackend {
         // d2h
         Ok(ys
             .iter()
-            .map(|x| {
+            .map(|(_k, v)| {
                 // d2h
                 let t = std::time::Instant::now();
-                let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
+                let v = v.try_extract_tensor().unwrap();
+                //let v = v.try_extract::<_>().unwrap().view().clone().into_owned();
                 if profile {
                     println!("[ORT D2H]: {:?}", t.elapsed());
                 }

                 // f16->f32
                 let t_ = std::time::Instant::now();
-                let x = x.mapv(f16::to_f32);
+                let v = v.mapv(f16::to_f32);
                 if profile {
                     println!("[ORT f16->f32]: {:?}", t_.elapsed());
                 }
-                x
+                v
             })
             .collect::<Vec<Array<f32, IxDyn>>>())
     }
@@ -338,14 +353,13 @@ impl OrtBackend {
         // h2d
         let t = std::time::Instant::now();
         let xs = CowArray::from(xs);
-        let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
         if profile {
             println!("[ORT H2D]: {:?}", t.elapsed());
         }

         // run
         let t = std::time::Instant::now();
-        let ys = self.session.run(xs)?;
+        let ys = self.session.run(ort::inputs![xs.view()]?)?;
         if profile {
             println!("[ORT Inference]: {:?}", t.elapsed());
         }
@@ -353,39 +367,44 @@ impl OrtBackend {
         // d2h
         Ok(ys
             .iter()
-            .map(|x| {
+            .map(|(_k, v)| {
                 let t = std::time::Instant::now();
-                let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
+                let v = v.try_extract_tensor::<f32>().unwrap().into_owned();
+                //let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
                 if profile {
                     println!("[ORT D2H]: {:?}", t.elapsed());
                 }
-                x
+                v
             })
             .collect::<Vec<Array<f32, IxDyn>>>())
     }

-    pub fn output_shapes(&self) -> Vec<Vec<i32>> {
+    pub fn output_shapes(&self) -> Vec<Vec<i64>> {
         let mut shapes = Vec::new();
-        for o in &self.session.outputs {
-            let shape: Vec<_> = o
-                .dimensions()
-                .map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
-                .collect();
-            shapes.push(shape);
+        for output in &self.session.outputs {
+            if let ValueType::Tensor { ty: _, dimensions } = &output.output_type {
+                let shape = dimensions.clone();
+                shapes.push(shape);
+            } else {
+                panic!("Unsupported data format, {} - {}", file!(), line!());
+            }
         }
         shapes
     }

-    pub fn output_dtypes(&self) -> Vec<TensorElementDataType> {
+    pub fn output_dtypes(&self) -> Vec<TensorElementType> {
         let mut dtypes = Vec::new();
-        self.session
-            .outputs
-            .iter()
-            .for_each(|x| dtypes.push(x.output_type));
+        for output in &self.session.outputs {
+            if let ValueType::Tensor { ty, dimensions: _ } = &output.output_type {
+                dtypes.push(ty.clone());
+            } else {
+                panic!("Unsupported data format, {} - {}", file!(), line!());
+            }
+        }
         dtypes
     }

-    pub fn input_shapes(&self) -> &Vec<Vec<i32>> {
+    pub fn input_shapes(&self) -> &Vec<Vec<i64>> {
         &self.inputs.shapes
     }

@@ -393,11 +412,11 @@ impl OrtBackend {
         &self.inputs.names
     }

-    pub fn input_dtypes(&self) -> &Vec<TensorElementDataType> {
+    pub fn input_dtypes(&self) -> &Vec<TensorElementType> {
         &self.inputs.dtypes
     }

-    pub fn dtype(&self) -> TensorElementDataType {
+    pub fn dtype(&self) -> TensorElementType {
         self.input_dtypes()[0]
     }
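The `set_ep_trt` hunks above keep building the TensorRT profile shapes as plain strings of the form `input_name:dim1xdim2x...`, joined by commas, now passed through `with_profile_min_shapes` / `with_profile_opt_shapes` / `with_profile_max_shapes`. A tiny sketch of how such a string is composed (the input name and shape here are placeholders):

```python
# Compose a TensorRT profile string of the shape the Rust code above builds:
# "input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,..."
inputs = {"images": (1, 3, 640, 640)}  # placeholder input name and static shape

profile = ",".join(f"{name}:{'x'.join(str(d) for d in shape)}" for name, shape in inputs.items())
print(profile)  # -> images:1x3x640x640
```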
diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py
index d3a2a8711a..4f641eeecd 100644
--- a/ultralytics/engine/exporter.py
+++ b/ultralytics/engine/exporter.py
@@ -890,8 +890,10 @@ class Exporter:
         tmp_file = f / "tmp_tflite_int8_calibration_images.npy"  # int8 calibration images file
         if self.args.data:
             f.mkdir()
-            images = [batch["img"].permute(0, 2, 3, 1) for batch in self.get_int8_calibration_dataloader(prefix)]
-            images = torch.cat(images, 0).float()
+            images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
+            images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
+                0, 2, 3, 1
+            )
             np.save(str(tmp_file), images.numpy().astype(np.float32))  # BHWC
             np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]

diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py
index 9e6d38b49f..b9312fefdb 100644
--- a/ultralytics/nn/autobackend.py
+++ b/ultralytics/nn/autobackend.py
@@ -189,10 +189,32 @@ class AutoBackend(nn.Module):
                 check_requirements("numpy==1.23.5")
             import onnxruntime

-            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
+            providers = onnxruntime.get_available_providers()
+            if not cuda and "CUDAExecutionProvider" in providers:
+                providers.remove("CUDAExecutionProvider")
+            elif cuda and "CUDAExecutionProvider" not in providers:
+                LOGGER.warning("WARNING ⚠️ Failed to start ONNX Runtime session with CUDA. Falling back to CPU...")
+                device = torch.device("cpu")
+                cuda = False
+            LOGGER.info(f"Preferring ONNX Runtime {providers[0]}")
             session = onnxruntime.InferenceSession(w, providers=providers)
             output_names = [x.name for x in session.get_outputs()]
             metadata = session.get_modelmeta().custom_metadata_map
+            dynamic = isinstance(session.get_outputs()[0].shape[0], str)
+            if not dynamic:
+                io = session.io_binding()
+                bindings = []
+                for output in session.get_outputs():
+                    y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device)
+                    io.bind_output(
+                        name=output.name,
+                        device_type=device.type,
+                        device_id=device.index if cuda else 0,
+                        element_type=np.float16 if fp16 else np.float32,
+                        shape=tuple(y_tensor.shape),
+                        buffer_ptr=y_tensor.data_ptr(),
+                    )
+                    bindings.append(y_tensor)

         # OpenVINO
         elif xml:
@@ -477,8 +499,22 @@ class AutoBackend(nn.Module):

         # ONNX Runtime
         elif self.onnx:
-            im = im.cpu().numpy()  # torch to numpy
-            y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
+            if self.dynamic:
+                im = im.cpu().numpy()  # torch to numpy
+                y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
+            else:
+                if not self.cuda:
+                    im = im.cpu()
+                self.io.bind_input(
+                    name="images",
+                    device_type=im.device.type,
+                    device_id=im.device.index if im.device.type == "cuda" else 0,
+                    element_type=np.float16 if self.fp16 else np.float32,
+                    shape=tuple(im.shape),
+                    buffer_ptr=im.data_ptr(),
+                )
+                self.session.run_with_iobinding(self.io)
+                y = self.bindings

         # OpenVINO
         elif self.xml:
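The `autobackend.py` hunks above bind ONNX Runtime outputs once at load time and reuse them through `run_with_iobinding`, so static-shape models avoid an extra host copy per call. A stripped-down sketch of the same pattern outside `AutoBackend`, assuming a CPU provider, a static-shape model file named `yolo11n.onnx`, and float32 inputs/outputs:

```python
import numpy as np
import onnxruntime
import torch

session = onnxruntime.InferenceSession("yolo11n.onnx", providers=["CPUExecutionProvider"])  # assumed model file
io = session.io_binding()

inp = session.get_inputs()[0]
out = session.get_outputs()[0]
im = torch.zeros(*inp.shape, dtype=torch.float32)  # assumes a fully static input shape
y = torch.empty(*out.shape, dtype=torch.float32)  # preallocated output buffer

io.bind_input(
    name=inp.name,
    device_type=im.device.type,
    device_id=0,
    element_type=np.float32,
    shape=tuple(im.shape),
    buffer_ptr=im.data_ptr(),
)
io.bind_output(
    name=out.name,
    device_type=y.device.type,
    device_id=0,
    element_type=np.float32,
    shape=tuple(y.shape),
    buffer_ptr=y.data_ptr(),
)
session.run_with_iobinding(io)  # results are written directly into `y`
```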