Merge branch 'main' into no-grad-test

Branch: no-grad-test
Francesco Mattioli committed via GitHub 1 month ago
commit ea486b4ccc
  1. docs/en/guides/triton-inference-server.md (26 changed lines)
  2. examples/YOLOv8-ONNXRuntime-Rust/Cargo.toml (11 changed lines)
  3. examples/YOLOv8-ONNXRuntime-Rust/README.md (2 changed lines)
  4. examples/YOLOv8-ONNXRuntime-Rust/src/cli.rs (2 changed lines)
  5. examples/YOLOv8-ONNXRuntime-Rust/src/lib.rs (42 changed lines)
  6. examples/YOLOv8-ONNXRuntime-Rust/src/main.rs (2 changed lines)
  7. examples/YOLOv8-ONNXRuntime-Rust/src/model.rs (31 changed lines)
  8. examples/YOLOv8-ONNXRuntime-Rust/src/ort_backend.rs (183 changed lines)
  9. ultralytics/engine/exporter.py (6 changed lines)
  10. ultralytics/nn/autobackend.py (42 changed lines)

@@ -80,6 +80,28 @@ The Triton Model Repository is a storage location where Triton can access and lo
# Create config file
(triton_model_path / "config.pbtxt").touch()
# (Optional) Enable TensorRT for GPU inference
# First run will be slow due to TensorRT engine conversion
import json

data = {
    "optimization": {
        "execution_accelerators": {
            "gpu_execution_accelerator": [
                {
                    "name": "tensorrt",
                    # "parameters" is a repeated field, so it must be a list;
                    # duplicate dict keys would silently keep only the last entry
                    "parameters": [
                        {"key": "precision_mode", "value": "FP16"},
                        {"key": "max_workspace_size_bytes", "value": "3221225472"},
                        {"key": "trt_engine_cache_enable", "value": "1"},
                    ],
                }
            ]
        }
    }
}

with open(triton_model_path / "config.pbtxt", "w") as f:
    json.dump(data, f, indent=4)
```
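The JSON above is what this commit writes; once the server is running, the parsed config can be read back to confirm the accelerator block took effect. A minimal sketch, assuming a Triton instance on localhost:8000 and a model named "yolo" (both assumptions, not part of this diff):

```python
from tritonclient.http import InferenceServerClient

# Connect to a locally running Triton server (assumed address)
client = InferenceServerClient(url="localhost:8000", verbose=False, ssl=False)

# Once the model is ready, inspect the config Triton actually parsed
if client.is_model_ready("yolo"):
    config = client.get_model_config("yolo")
    print(config.get("optimization", {}))  # should echo the TensorRT accelerator settings
```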
## Running Triton Inference Server
@@ -94,7 +116,7 @@ import time
from tritonclient.http import InferenceServerClient
# Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
tag = "nvcr.io/nvidia/tritonserver:23.09-py3" # 6.4 GB
tag = "nvcr.io/nvidia/tritonserver:24.09-py3" # 8.57 GB
# Pull the image
subprocess.call(f"docker pull {tag}", shell=True)
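The guide then launches the container with the model repository mounted; a minimal sketch of that step, assuming the `triton_repo_path` variable defined earlier in the guide (add `--gpus=all` to use the TensorRT accelerator configured above):

```python
# Run the server detached, with the repository mounted and the HTTP port exposed
container_id = (
    subprocess.check_output(
        f"docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} "
        f"tritonserver --model-repository=/models",
        shell=True,
    )
    .decode("utf-8")
    .strip()
)
```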
@@ -187,7 +209,7 @@ Setting up [Ultralytics YOLO11](https://docs.ultralytics.com/models/yolov8/) wit
from tritonclient.http import InferenceServerClient
# Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
tag = "nvcr.io/nvidia/tritonserver:23.09-py3"
tag = "nvcr.io/nvidia/tritonserver:24.09-py3"
subprocess.call(f"docker pull {tag}", shell=True)

@@ -9,11 +9,11 @@ edition = "2021"
[dependencies]
clap = { version = "4.2.4", features = ["derive"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png", "webp-encoder"] }
imageproc = { version = "0.23.0", default-features = false }
ndarray = { version = "0.15.6" }
ort = { version = "1.16.3", default-features = false, features = ["load-dynamic", "copy-dylibs", "half"] }
rusttype = { version = "0.9", default-features = false }
image = { version = "0.25.2"}
imageproc = { version = "0.25.0"}
ndarray = { version = "0.16" }
ort = { version = "2.0.0-rc.5", features = ["cuda", "tensorrt"]}
rusttype = { version = "0.9.3" }
anyhow = { version = "1.0.75" }
regex = { version = "1.5.4" }
rand = { version = "0.8.5" }
@@ -21,3 +21,4 @@ chrono = { version = "0.4.30" }
half = { version = "2.3.1" }
dirs = { version = "5.0.1" }
ureq = { version = "2.9.1" }
ab_glyph = "0.2.29"

@@ -5,7 +5,7 @@ This repository provides a Rust demo for performing YOLOv8 tasks like `Classific
## Recently Updated
- Add YOLOv8-OBB demo
- Update ONNXRuntime to 1.17.x
- Update ONNXRuntime to 1.19.x
Newly updated YOLOv8 example code is located in [this repository](https://github.com/jamjamjon/usls/tree/main/examples/yolo)

@@ -15,7 +15,7 @@ pub struct Args {
/// device id
#[arg(long, default_value_t = 0)]
pub device_id: u32,
pub device_id: i32,
/// using TensorRT EP
#[arg(long)]

@@ -117,3 +117,45 @@ pub fn check_font(font: &str) -> rusttype::Font<'static> {
let buffer = std::fs::read(font_path).unwrap();
rusttype::Font::try_from_vec(buffer).unwrap()
}
use ab_glyph::FontArc;

pub fn load_font() -> FontArc {
    use std::io::{Read, Write}; // needed for read_to_end / write_all below
    use std::path::Path;

    let font_path = Path::new("./font/Arial.ttf");
    match font_path.try_exists() {
        Ok(true) => {
            // font already cached on disk
            let buffer = std::fs::read(font_path).unwrap();
            FontArc::try_from_vec(buffer).unwrap()
        }
        Ok(false) => {
            std::fs::create_dir_all("./font").unwrap();
            println!("Downloading font...");
            let source_url = "https://ultralytics.com/assets/Arial.ttf";
            let resp = ureq::get(source_url)
                .timeout(std::time::Duration::from_secs(500))
                .call()
                .unwrap_or_else(|err| panic!("> Failed to download font: {source_url}: {err:?}"));

            // read the response body into a buffer
            let mut buffer = vec![];
            let total_size = resp
                .header("Content-Length")
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap();
            let _ = resp
                .into_reader()
                .take(total_size)
                .read_to_end(&mut buffer)
                .unwrap();

            // cache the font on disk for the next run
            let mut fd = std::fs::File::create(font_path).unwrap();
            fd.write_all(&buffer).unwrap();
            println!("Font saved at: {:?}", font_path.display());
            FontArc::try_from_vec(buffer).unwrap()
        }
        Err(e) => {
            panic!("Failed to access font path: {e}");
        }
    }
}

@@ -6,7 +6,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
// 1. load image
let x = image::io::Reader::open(&args.source)?
let x = image::ImageReader::open(&args.source)?
.with_guessed_format()?
.decode()?;

@@ -1,5 +1,6 @@
#![allow(clippy::type_complexity)]
use ab_glyph::FontArc;
use anyhow::Result;
use image::{DynamicImage, GenericImageView, ImageBuffer};
use ndarray::{s, Array, Axis, IxDyn};
@@ -7,7 +8,7 @@ use rand::{thread_rng, Rng};
use std::path::PathBuf;
use crate::{
check_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
load_font, gen_time_string, non_max_suppression, Args, Batch, Bbox, Embedding, OrtBackend,
OrtConfig, OrtEP, Point2, YOLOResult, YOLOTask, SKELETON,
};
@@ -36,10 +37,10 @@ impl YOLOv8 {
let ep = if config.trt {
OrtEP::Trt(config.device_id)
} else if config.cuda {
OrtEP::Cuda(config.device_id)
OrtEP::CUDA(config.device_id)
} else {
OrtEP::Cpu
};
OrtEP::CPU
};
// batch
let batch = Batch {
@@ -330,12 +331,19 @@ impl YOLOv8 {
// coefs * proto -> mask
let coefs = Array::from_shape_vec((1, nm), coefs)?; // (n, nm)
let proto = proto.to_owned().into_shape((nm, nh * nw))?; // (nm, nh*nw)
let mask = coefs.dot(&proto).into_shape((nh, nw, 1))?; // (nh, nw, n)
let proto = proto.to_owned();
let proto = proto.to_shape((nm, nh * nw))?; // (nm, nh*nw)
let mask = coefs.dot(&proto); // (n, nh*nw)
let mask = mask.to_shape((nh, nw, 1))?; // (nh, nw, 1)
// build image from ndarray
let mask_im: ImageBuffer<image::Luma<_>, Vec<f32>> =
match ImageBuffer::from_raw(nw as u32, nh as u32, mask.into_raw_vec()) {
match ImageBuffer::from_raw(
nw as u32,
nh as u32,
mask.to_owned().into_raw_vec_and_offset().0,
) {
Some(image) => image,
None => panic!("can not create image from ndarray"),
};
@@ -410,7 +418,7 @@ impl YOLOv8 {
skeletons: Option<&[(usize, usize)]>,
) {
// check font then load
let font = check_font("Arial.ttf");
let font: FontArc = load_font();
for (_idb, (img0, y)) in xs0.iter().zip(ys.iter()).enumerate() {
let mut img = img0.to_rgb8();
@@ -422,12 +430,13 @@
let legend_size = img.width().max(img.height()) / scale;
let x = img.width() / 20;
let y = img.height() / 20 + i as u32 * legend_size;
imageproc::drawing::draw_text_mut(
&mut img,
image::Rgb([0, 255, 0]),
x as i32,
y as i32,
rusttype::Scale::uniform(legend_size as f32 - 1.),
legend_size as f32,
&font,
&legend,
);
@@ -454,7 +463,7 @@ impl YOLOv8 {
image::Rgb(self.color_palette[bbox.id()].into()),
bbox.xmin() as i32,
(bbox.ymin() - legend_size as f32) as i32,
rusttype::Scale::uniform(legend_size as f32 - 1.),
legend_size as f32,
&font,
&legend,
);
@@ -551,7 +560,7 @@ impl YOLOv8 {
None => String::from(""),
},
self.engine.ep(),
if let OrtEP::Cpu = self.engine.ep() {
if let OrtEP::CPU = self.engine.ep() {
""
} else {
"(May still fall back to CPU)"

@@ -2,11 +2,13 @@ use anyhow::Result;
use clap::ValueEnum;
use half::f16;
use ndarray::{Array, CowArray, IxDyn};
use ort::execution_providers::{CUDAExecutionProviderOptions, TensorRTExecutionProviderOptions};
use ort::tensor::TensorElementDataType;
use ort::{Environment, ExecutionProvider, Session, SessionBuilder, Value};
use ort::{
CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
TensorRTExecutionProvider,
};
use ort::{Session, SessionBuilder};
use ort::{TensorElementType, ValueType};
use regex::Regex;
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
pub enum YOLOTask {
// YOLO tasks
@@ -19,9 +21,9 @@ pub enum YOLOTask {
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum OrtEP {
// ONNXRuntime execution provider
Cpu,
Cuda(u32),
Trt(u32),
CPU,
CUDA(i32),
Trt(i32),
}
#[derive(Debug)]
@@ -44,8 +46,9 @@ impl Default for Batch {
#[derive(Debug, Default)]
pub struct OrtInputs {
// ONNX model inputs attrs
pub shapes: Vec<Vec<i32>>,
pub dtypes: Vec<TensorElementDataType>,
pub shapes: Vec<Vec<i64>>,
pub dtypes: Vec<TensorElementType>,
pub names: Vec<String>,
pub sizes: Vec<Vec<u32>>,
}
@@ -56,12 +59,19 @@ impl OrtInputs {
let mut dtypes = Vec::new();
let mut names = Vec::new();
for i in session.inputs.iter() {
let shape: Vec<i32> = i
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
dtypes.push(i.input_type);
if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
dtypes.push(ty.clone());
let shape = dimensions.clone();
shapes.push(shape);
} else {
panic!("Unsupported input data format, {} - {}", file!(), line!());
}
names.push(i.name.clone());
}
Self {
@@ -97,12 +107,14 @@ pub struct OrtBackend {
impl OrtBackend {
pub fn build(args: OrtConfig) -> Result<Self> {
// build env & session
let env = Environment::builder()
.with_name("YOLOv8")
.with_log_level(ort::LoggingLevel::Verbose)
.build()?
.into_arc();
let session = SessionBuilder::new(&env)?.with_model_from_file(&args.f)?;
// in ort 2.x the explicit Environment is removed; the session is built directly
let sessionbuilder = SessionBuilder::new()?;
let session = sessionbuilder.commit_from_file(&args.f)?;
// get inputs
let mut inputs = OrtInputs::new(&session);
@@ -142,16 +154,19 @@
// build provider
let (ep, provider) = match args.ep {
OrtEP::Cuda(device_id) => Self::set_ep_cuda(device_id),
OrtEP::CUDA(device_id) => Self::set_ep_cuda(device_id),
OrtEP::Trt(device_id) => Self::set_ep_trt(device_id, args.trt_fp16, &batch, &inputs),
_ => (OrtEP::Cpu, ExecutionProvider::CPU(Default::default())),
_ => (
OrtEP::CPU,
ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
),
};
// build session again with the new provider
let session = SessionBuilder::new(&env)?
let session = SessionBuilder::new()?
// .with_optimization_level(ort::GraphOptimizationLevel::Level3)?
.with_execution_providers([provider])?
.with_model_from_file(args.f)?;
.commit_from_file(args.f)?;
// task: using given one or guessing
let task = match args.task {
@@ -185,57 +200,58 @@ impl OrtBackend {
pub fn fetch_inputs_from_session(
session: &Session,
) -> (Vec<Vec<i32>>, Vec<TensorElementDataType>, Vec<String>) {
) -> (Vec<Vec<i64>>, Vec<TensorElementType>, Vec<String>) {
// get inputs attrs from ONNX model
let mut shapes = Vec::new();
let mut dtypes = Vec::new();
let mut names = Vec::new();
for i in session.inputs.iter() {
let shape: Vec<i32> = i
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
dtypes.push(i.input_type);
if let ort::ValueType::Tensor { ty, dimensions } = &i.input_type {
dtypes.push(ty.clone());
let shape = dimensions.clone();
shapes.push(shape);
} else {
panic!("Unsupported input data format, {} - {}", file!(), line!());
}
names.push(i.name.clone());
}
(shapes, dtypes, names)
}
pub fn set_ep_cuda(device_id: u32) -> (OrtEP, ExecutionProvider) {
// set CUDA
if ExecutionProvider::CUDA(Default::default()).is_available() {
pub fn set_ep_cuda(device_id: i32) -> (OrtEP, ExecutionProviderDispatch) {
let cuda_provider = CUDAExecutionProvider::default().with_device_id(device_id);
if let Ok(true) = cuda_provider.is_available() {
(
OrtEP::Cuda(device_id),
ExecutionProvider::CUDA(CUDAExecutionProviderOptions {
device_id,
..Default::default()
}),
OrtEP::CUDA(device_id),
ExecutionProviderDispatch::from(cuda_provider),
)
} else {
println!("> CUDA is not available! Using CPU.");
(OrtEP::Cpu, ExecutionProvider::CPU(Default::default()))
(
OrtEP::CPU,
ExecutionProviderDispatch::from(CPUExecutionProvider::default()),
)
}
}
pub fn set_ep_trt(
device_id: u32,
device_id: i32,
fp16: bool,
batch: &Batch,
inputs: &OrtInputs,
) -> (OrtEP, ExecutionProvider) {
) -> (OrtEP, ExecutionProviderDispatch) {
// set TensorRT
if ExecutionProvider::TensorRT(Default::default()).is_available() {
let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
let trt_provider = TensorRTExecutionProvider::default().with_device_id(device_id);
// dtype match checking
if inputs.dtypes[0] == TensorElementDataType::Float16 && !fp16 {
if let Ok(true) = trt_provider.is_available() {
let (height, width) = (inputs.sizes[0][0], inputs.sizes[0][1]);
if inputs.dtypes[0] == TensorElementType::Float16 && !fp16 {
panic!(
"Dtype mismatch: model input is {:?} but `--fp16` was not set. Re-run with `--fp16`",
inputs.dtypes[0]
);
}
// dynamic shape: input_tensor_1:dim_1xdim_2x...,input_tensor_2:dim_3xdim_4x...,...
let mut opt_string = String::new();
let mut min_string = String::new();
@@ -251,17 +267,16 @@ impl OrtBackend {
let _ = opt_string.pop();
let _ = min_string.pop();
let _ = max_string.pop();
let trt_provider = trt_provider
.with_profile_opt_shapes(opt_string)
.with_profile_min_shapes(min_string)
.with_profile_max_shapes(max_string)
.with_fp16(fp16)
.with_timing_cache(true);
(
OrtEP::Trt(device_id),
ExecutionProvider::TensorRT(TensorRTExecutionProviderOptions {
device_id,
fp16_enable: fp16,
timing_cache_enable: true,
profile_min_shapes: min_string,
profile_max_shapes: max_string,
profile_opt_shapes: opt_string,
..Default::default()
}),
ExecutionProviderDispatch::from(trt_provider),
)
} else {
println!("> TensorRT is not available! Try using CUDA...");
@@ -283,8 +298,8 @@ impl OrtBackend {
pub fn run(&self, xs: Array<f32, IxDyn>, profile: bool) -> Result<Vec<Array<f32, IxDyn>>> {
// ORT inference
match self.dtype() {
TensorElementDataType::Float16 => self.run_fp16(xs, profile),
TensorElementDataType::Float32 => self.run_fp32(xs, profile),
TensorElementType::Float16 => self.run_fp16(xs, profile),
TensorElementType::Float32 => self.run_fp32(xs, profile),
_ => todo!(),
}
}
@@ -300,14 +315,13 @@ impl OrtBackend {
// h2d
let t = std::time::Instant::now();
let xs = CowArray::from(xs);
let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
if profile {
println!("[ORT H2D]: {:?}", t.elapsed());
}
// run
let t = std::time::Instant::now();
let ys = self.session.run(xs)?;
let ys = self.session.run(ort::inputs![xs.view()]?)?;
if profile {
println!("[ORT Inference]: {:?}", t.elapsed());
}
@@ -315,21 +329,22 @@
// d2h
Ok(ys
.iter()
.map(|x| {
.map(|(_k, v)| {
// d2h
let t = std::time::Instant::now();
let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
let v = v.try_extract_tensor().unwrap();
if profile {
println!("[ORT D2H]: {:?}", t.elapsed());
}
// f16->f32
let t_ = std::time::Instant::now();
let x = x.mapv(f16::to_f32);
let v = v.mapv(f16::to_f32);
if profile {
println!("[ORT f16->f32]: {:?}", t_.elapsed());
}
x
v
})
.collect::<Vec<Array<_, _>>>())
}
@@ -338,14 +353,13 @@ impl OrtBackend {
// h2d
let t = std::time::Instant::now();
let xs = CowArray::from(xs);
let xs = vec![Value::from_array(self.session.allocator(), &xs)?];
if profile {
println!("[ORT H2D]: {:?}", t.elapsed());
}
// run
let t = std::time::Instant::now();
let ys = self.session.run(xs)?;
let ys = self.session.run(ort::inputs![xs.view()]?)?;
if profile {
println!("[ORT Inference]: {:?}", t.elapsed());
}
@@ -353,39 +367,44 @@
// d2h
Ok(ys
.iter()
.map(|x| {
.map(|(_k, v)| {
let t = std::time::Instant::now();
let x = x.try_extract::<_>().unwrap().view().clone().into_owned();
let v = v.try_extract_tensor::<f32>().unwrap().into_owned();
if profile {
println!("[ORT D2H]: {:?}", t.elapsed());
}
x
v
})
.collect::<Vec<Array<_, _>>>())
}
pub fn output_shapes(&self) -> Vec<Vec<i32>> {
pub fn output_shapes(&self) -> Vec<Vec<i64>> {
let mut shapes = Vec::new();
for o in &self.session.outputs {
let shape: Vec<_> = o
.dimensions()
.map(|x| if let Some(x) = x { x as i32 } else { -1i32 })
.collect();
shapes.push(shape);
for output in &self.session.outputs {
if let ValueType::Tensor { ty: _, dimensions } = &output.output_type {
let shape = dimensions.clone();
shapes.push(shape);
} else {
panic!("Unsupported output data format, {} - {}", file!(), line!());
}
}
shapes
}
pub fn output_dtypes(&self) -> Vec<TensorElementDataType> {
pub fn output_dtypes(&self) -> Vec<TensorElementType> {
let mut dtypes = Vec::new();
self.session
.outputs
.iter()
.for_each(|x| dtypes.push(x.output_type));
for output in &self.session.outputs {
if let ValueType::Tensor { ty, dimensions: _ } = &output.output_type {
dtypes.push(ty.clone());
} else {
panic!("Unsupported output data format, {} - {}", file!(), line!());
}
}
dtypes
}
pub fn input_shapes(&self) -> &Vec<Vec<i32>> {
pub fn input_shapes(&self) -> &Vec<Vec<i64>> {
&self.inputs.shapes
}
@@ -393,11 +412,11 @@ impl OrtBackend {
&self.inputs.names
}
pub fn input_dtypes(&self) -> &Vec<TensorElementDataType> {
pub fn input_dtypes(&self) -> &Vec<TensorElementType> {
&self.inputs.dtypes
}
pub fn dtype(&self) -> TensorElementDataType {
pub fn dtype(&self) -> TensorElementType {
self.input_dtypes()[0]
}

@@ -890,8 +890,10 @@ class Exporter:
tmp_file = f / "tmp_tflite_int8_calibration_images.npy" # int8 calibration images file
if self.args.data:
f.mkdir()
images = [batch["img"].permute(0, 2, 3, 1) for batch in self.get_int8_calibration_dataloader(prefix)]
images = torch.cat(images, 0).float()
images = [batch["img"] for batch in self.get_int8_calibration_dataloader(prefix)]
images = torch.nn.functional.interpolate(torch.cat(images, 0).float(), size=self.imgsz).permute(
0, 2, 3, 1
)
np.save(str(tmp_file), images.numpy().astype(np.float32)) # BHWC
np_data = [["images", tmp_file, [[[[0, 0, 0]]]], [[[[255, 255, 255]]]]]]
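The new calibration path concatenates the dataloader batches, resizes them to the export image size, and only then permutes BCHW to BHWC for the `.npy` file. A standalone sketch with dummy tensors (batch shapes and the `imgsz` value are made up for illustration):

```python
import numpy as np
import torch

imgsz = (640, 640)  # stand-in for self.imgsz
batches = [torch.rand(4, 3, 512, 512) for _ in range(2)]  # dataloader batches, BCHW

images = torch.cat(batches, 0).float()  # (8, 3, 512, 512)
images = torch.nn.functional.interpolate(images, size=imgsz)  # resize to export size
images = images.permute(0, 2, 3, 1)  # BCHW -> BHWC, as the calibrator expects
print(images.shape)  # torch.Size([8, 640, 640, 3])
np.save("tmp_tflite_int8_calibration_images.npy", images.numpy().astype(np.float32))
```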

@@ -189,10 +189,32 @@ class AutoBackend(nn.Module):
check_requirements("numpy==1.23.5")
import onnxruntime
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
providers = onnxruntime.get_available_providers()
if not cuda and "CUDAExecutionProvider" in providers:
providers.remove("CUDAExecutionProvider")
elif cuda and "CUDAExecutionProvider" not in providers:
LOGGER.warning("WARNING ⚠ Failed to start ONNX Runtime session with CUDA. Falling back to CPU...")
device = torch.device("cpu")
cuda = False
LOGGER.info(f"Preferring ONNX Runtime {providers[0]}")
session = onnxruntime.InferenceSession(w, providers=providers)
output_names = [x.name for x in session.get_outputs()]
metadata = session.get_modelmeta().custom_metadata_map
dynamic = isinstance(session.get_outputs()[0].shape[0], str)
if not dynamic:
io = session.io_binding()
bindings = []
for output in session.get_outputs():
y_tensor = torch.empty(output.shape, dtype=torch.float16 if fp16 else torch.float32).to(device)
io.bind_output(
name=output.name,
device_type=device.type,
device_id=device.index if cuda else 0,
element_type=np.float16 if fp16 else np.float32,
shape=tuple(y_tensor.shape),
buffer_ptr=y_tensor.data_ptr(),
)
bindings.append(y_tensor)
# OpenVINO
elif xml:
@@ -477,8 +499,22 @@
# ONNX Runtime
elif self.onnx:
im = im.cpu().numpy() # torch to numpy
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
if self.dynamic:
im = im.cpu().numpy() # torch to numpy
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
else:
if not self.cuda:
im = im.cpu()
self.io.bind_input(
name="images",
device_type=im.device.type,
device_id=im.device.index if im.device.type == "cuda" else 0,
element_type=np.float16 if self.fp16 else np.float32,
shape=tuple(im.shape),
buffer_ptr=im.data_ptr(),
)
self.session.run_with_iobinding(self.io)
y = self.bindings
# OpenVINO
elif self.xml:
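The binding above allocates each output tensor once and lets ONNX Runtime write into it directly, skipping the numpy round-trip on every call. A self-contained sketch of the same IO-binding pattern on CPU; the "model.onnx" path, the "images" input name, and the input shape are assumptions for the example:

```python
import numpy as np
import onnxruntime
import torch

session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
io = session.io_binding()

# Bind the input buffer by pointer (no copy into numpy)
im = torch.rand(1, 3, 640, 640, dtype=torch.float32).contiguous()
io.bind_input(
    name="images",
    device_type="cpu",
    device_id=0,
    element_type=np.float32,
    shape=tuple(im.shape),
    buffer_ptr=im.data_ptr(),
)

# Pre-allocate the output and bind it the same way (assumes a static output shape)
out = session.get_outputs()[0]
y = torch.empty(tuple(out.shape), dtype=torch.float32)
io.bind_output(
    name=out.name,
    device_type="cpu",
    device_id=0,
    element_type=np.float32,
    shape=tuple(y.shape),
    buffer_ptr=y.data_ptr(),
)

session.run_with_iobinding(io)  # results land directly in `y`
```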
