@@ -135,8 +135,8 @@ class AutoBackend(nn.Module):
         if not (pt or triton or nn_module):
             w = attempt_download_asset(w)

-        # Load model
-        if nn_module:  # in-memory PyTorch model
+        # In-memory PyTorch model
+        if nn_module:
             model = weights.to(device)
             model = model.fuse(verbose=verbose) if fuse else model
             if hasattr(model, "kpt_shape"):
@@ -146,7 +146,9 @@ class AutoBackend(nn.Module):
             model.half() if fp16 else model.float()
             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
             pt = True
-        elif pt:  # PyTorch
+
+        # PyTorch
+        elif pt:
             from ultralytics.nn.tasks import attempt_load_weights

             model = attempt_load_weights(
@@ -158,18 +160,24 @@ class AutoBackend(nn.Module):
             names = model.module.names if hasattr(model, "module") else model.names  # get class names
             model.half() if fp16 else model.float()
             self.model = model  # explicitly assign for to(), cpu(), cuda(), half()
-        elif jit:  # TorchScript
+
+        # TorchScript
+        elif jit:
             LOGGER.info(f"Loading {w} for TorchScript inference...")
             extra_files = {"config.txt": ""}  # model metadata
             model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
             model.half() if fp16 else model.float()
             if extra_files["config.txt"]:  # load metadata dict
                 metadata = json.loads(extra_files["config.txt"], object_hook=lambda x: dict(x.items()))
-        elif dnn:  # ONNX OpenCV DNN
+
+        # ONNX OpenCV DNN
+        elif dnn:
             LOGGER.info(f"Loading {w} for ONNX OpenCV DNN inference...")
             check_requirements("opencv-python>=4.5.4")
             net = cv2.dnn.readNetFromONNX(w)
-        elif onnx:  # ONNX Runtime
+
+        # ONNX Runtime
+        elif onnx:
             LOGGER.info(f"Loading {w} for ONNX Runtime inference...")
             check_requirements(("onnx", "onnxruntime-gpu" if cuda else "onnxruntime"))
             import onnxruntime
@@ -177,11 +185,13 @@ class AutoBackend(nn.Module):
             providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if cuda else ["CPUExecutionProvider"]
             session = onnxruntime.InferenceSession(w, providers=providers)
             output_names = [x.name for x in session.get_outputs()]
-            metadata = session.get_modelmeta().custom_metadata_map  # metadata
-        elif xml:  # OpenVINO
+            metadata = session.get_modelmeta().custom_metadata_map
+
+        # OpenVINO
+        elif xml:
             LOGGER.info(f"Loading {w} for OpenVINO inference...")
-            check_requirements("openvino>=2023.3")  # requires openvino: https://pypi.org/project/openvino-dev/
-            import openvino as ov  # noqa
+            check_requirements("openvino>=2023.3")
+            import openvino as ov

             core = ov.Core()
             w = Path(w)
@@ -193,9 +203,18 @@ class AutoBackend(nn.Module):
             batch_dim = ov.get_batch(ov_model)
             if batch_dim.is_static:
                 batch_size = batch_dim.get_length()
-            ov_compiled_model = core.compile_model(ov_model, device_name="AUTO")  # AUTO selects best available device
+
+            inference_mode = "LATENCY"  # either 'LATENCY', 'THROUGHPUT' (not recommended), or 'CUMULATIVE_THROUGHPUT'
+            ov_compiled_model = core.compile_model(
+                ov_model,
+                device_name="AUTO",  # AUTO selects best available device, do not modify
+                config={"PERFORMANCE_HINT": inference_mode},
+            )
+            input_name = ov_compiled_model.input().get_any_name()
             metadata = w.parent / "metadata.yaml"
-        elif engine:  # TensorRT
+
+        # TensorRT
+        elif engine:
             LOGGER.info(f"Loading {w} for TensorRT inference...")
             try:
                 import tensorrt as trt  # noqa https://developer.nvidia.com/nvidia-tensorrt-download
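The hunk above switches OpenVINO compilation to an explicit `PERFORMANCE_HINT`. For context, a minimal standalone sketch of that API (assuming `openvino>=2023.3`; the `model.xml` path and dummy input are placeholders, not part of this diff):

```python
import numpy as np
import openvino as ov

core = ov.Core()
ov_model = core.read_model("model.xml")  # placeholder path to an exported OpenVINO model
compiled = core.compile_model(
    ov_model,
    device_name="AUTO",  # AUTO lets OpenVINO pick the best available device
    config={"PERFORMANCE_HINT": "LATENCY"},  # or "THROUGHPUT" / "CUMULATIVE_THROUGHPUT"
)
im = np.zeros((1, 3, 640, 640), dtype=np.float32)  # dummy BCHW input
y = list(compiled(im).values())  # synchronous call, same pattern the LATENCY branch uses
```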
@@ -234,20 +253,26 @@ class AutoBackend(nn.Module):
                 bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
             binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
             batch_size = bindings["images"].shape[0]  # if dynamic, this is instead max batch size
-        elif coreml:  # CoreML
+
+        # CoreML
+        elif coreml:
             LOGGER.info(f"Loading {w} for CoreML inference...")
             import coremltools as ct

             model = ct.models.MLModel(w)
             metadata = dict(model.user_defined_metadata)
-        elif saved_model:  # TF SavedModel
+
+        # TF SavedModel
+        elif saved_model:
             LOGGER.info(f"Loading {w} for TensorFlow SavedModel inference...")
             import tensorflow as tf

             keras = False  # assume TF1 saved_model
             model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
             metadata = Path(w) / "metadata.yaml"
-        elif pb:  # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+
+        # TF GraphDef
+        elif pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
             LOGGER.info(f"Loading {w} for TensorFlow GraphDef inference...")
             import tensorflow as tf

@@ -263,6 +288,8 @@ class AutoBackend(nn.Module):
             with open(w, "rb") as f:
                 gd.ParseFromString(f.read())
             frozen_func = wrap_frozen_graph(gd, inputs="x:0", outputs=gd_outputs(gd))
+
+        # TFLite or TFLite Edge TPU
         elif tflite or edgetpu:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
             try:  # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
                 from tflite_runtime.interpreter import Interpreter, load_delegate
@@ -287,9 +314,13 @@ class AutoBackend(nn.Module):
             with zipfile.ZipFile(w, "r") as model:
                 meta_file = model.namelist()[0]
                 metadata = ast.literal_eval(model.read(meta_file).decode("utf-8"))
-        elif tfjs:  # TF.js
+
+        # TF.js
+        elif tfjs:
             raise NotImplementedError("YOLOv8 TF.js inference is not currently supported.")
-        elif paddle:  # PaddlePaddle
+
+        # PaddlePaddle
+        elif paddle:
             LOGGER.info(f"Loading {w} for PaddlePaddle inference...")
             check_requirements("paddlepaddle-gpu" if cuda else "paddlepaddle")
             import paddle.inference as pdi  # noqa
@@ -304,7 +335,9 @@ class AutoBackend(nn.Module):
             input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
             output_names = predictor.get_output_names()
             metadata = w.parents[1] / "metadata.yaml"
-        elif ncnn:  # NCNN
+
+        # NCNN
+        elif ncnn:
             LOGGER.info(f"Loading {w} for NCNN inference...")
             check_requirements("git+https://github.com/Tencent/ncnn.git" if ARM64 else "ncnn")  # requires NCNN
             import ncnn as pyncnn
@@ -317,18 +350,21 @@ class AutoBackend(nn.Module):
             net.load_param(str(w))
             net.load_model(str(w.with_suffix(".bin")))
             metadata = w.parent / "metadata.yaml"
-        elif triton:  # NVIDIA Triton Inference Server
+
+        # NVIDIA Triton Inference Server
+        elif triton:
             check_requirements("tritonclient[all]")
             from ultralytics.utils.triton import TritonRemoteModel

             model = TritonRemoteModel(w)
+
+        # Any other format (unsupported)
         else:
             from ultralytics.engine.exporter import export_formats

             raise TypeError(
                 f"model='{w}' is not a supported model format. "
-                "See https://docs.ultralytics.com/modes/predict for help."
-                f"\n\n{export_formats()}"
+                f"See https://docs.ultralytics.com/modes/predict for help.\n\n{export_formats()}"
             )

         # Load external metadata YAML
@@ -380,21 +416,51 @@ class AutoBackend(nn.Module):
         if self.nhwc:
             im = im.permute(0, 2, 3, 1)  # torch BCHW to numpy BHWC shape(1,320,192,3)

-        if self.pt or self.nn_module:  # PyTorch
+        # PyTorch
+        if self.pt or self.nn_module:
             y = self.model(im, augment=augment, visualize=visualize, embed=embed)
-        elif self.jit:  # TorchScript
+
+        # TorchScript
+        elif self.jit:
             y = self.model(im)
-        elif self.dnn:  # ONNX OpenCV DNN
+
+        # ONNX OpenCV DNN
+        elif self.dnn:
             im = im.cpu().numpy()  # torch to numpy
             self.net.setInput(im)
             y = self.net.forward()
-        elif self.onnx:  # ONNX Runtime
+
+        # ONNX Runtime
+        elif self.onnx:
             im = im.cpu().numpy()  # torch to numpy
             y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
-        elif self.xml:  # OpenVINO
+
+        # OpenVINO
+        elif self.xml:
             im = im.cpu().numpy()  # FP32
-            y = list(self.ov_compiled_model(im).values())
-        elif self.engine:  # TensorRT
+
+            if self.inference_mode in {"THROUGHPUT", "CUMULATIVE_THROUGHPUT"}:  # optimized for larger batch-sizes
+                n = im.shape[0]  # number of images in batch
+                results = [None] * n  # preallocate list with None to match the number of images
+
+                def callback(request, userdata):
+                    """Places result in preallocated list using userdata index."""
+                    results[userdata] = request.results
+
+                # Create AsyncInferQueue, set the callback and start asynchronous inference for each input image
+                async_queue = self.ov.runtime.AsyncInferQueue(self.ov_compiled_model)
+                async_queue.set_callback(callback)
+                for i in range(n):
+                    # Start async inference with userdata=i to specify the position in results list
+                    async_queue.start_async(inputs={self.input_name: im[i : i + 1]}, userdata=i)  # keep image as BCHW
+                async_queue.wait_all()  # wait for all inference requests to complete
+                y = [list(r.values()) for r in results][0]
+
+            else:  # inference_mode = "LATENCY", optimized for fastest first result at batch-size 1
+                y = list(self.ov_compiled_model(im).values())
+
+        # TensorRT
+        elif self.engine:
             if self.dynamic and im.shape != self.bindings["images"].shape:
                 i = self.model.get_binding_index("images")
                 self.context.set_binding_shape(i, im.shape)  # reshape if dynamic
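The THROUGHPUT branch above relies on OpenVINO's `AsyncInferQueue`. A standalone sketch of that callback pattern, assuming `compiled` is an already-compiled OpenVINO model and `batch` is an `(N, 3, H, W)` float32 array (both placeholders, not from this diff):

```python
from openvino.runtime import AsyncInferQueue


def infer_batch_async(compiled, batch):
    """Run one async request per image and collect results in submission order."""
    n = len(batch)
    results = [None] * n  # one slot per image, filled by the callback

    def callback(request, userdata):
        results[userdata] = request.results  # dict mapping each output to an ndarray

    queue = AsyncInferQueue(compiled)  # pool of infer requests managed by OpenVINO
    queue.set_callback(callback)
    input_name = compiled.input().get_any_name()
    for i in range(n):
        queue.start_async(inputs={input_name: batch[i : i + 1]}, userdata=i)  # keep BCHW shape
    queue.wait_all()  # block until every request has completed
    return [list(r.values()) for r in results]
```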
@@ -407,7 +473,9 @@ class AutoBackend(nn.Module):
             self.binding_addrs["images"] = int(im.data_ptr())
             self.context.execute_v2(list(self.binding_addrs.values()))
             y = [self.bindings[x].data for x in sorted(self.output_names)]
-        elif self.coreml:  # CoreML
+
+        # CoreML
+        elif self.coreml:
             im = im[0].cpu().numpy()
             im_pil = Image.fromarray((im * 255).astype("uint8"))
             # im = im.resize((192, 320), Image.BILINEAR)
@@ -426,12 +494,16 @@ class AutoBackend(nn.Module):
                 y = list(y.values())
             elif len(y) == 2:  # segmentation model
                 y = list(reversed(y.values()))  # reversed for segmentation models (pred, proto)
-        elif self.paddle:  # PaddlePaddle
+
+        # PaddlePaddle
+        elif self.paddle:
             im = im.cpu().numpy().astype(np.float32)
             self.input_handle.copy_from_cpu(im)
             self.predictor.run()
             y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
-        elif self.ncnn:  # NCNN
+
+        # NCNN
+        elif self.ncnn:
             mat_in = self.pyncnn.Mat(im[0].cpu().numpy())
             ex = self.net.create_extractor()
             input_names, output_names = self.net.input_names(), self.net.output_names()
@@ -441,10 +513,14 @@ class AutoBackend(nn.Module):
                 mat_out = self.pyncnn.Mat()
                 ex.extract(output_name, mat_out)
                 y.append(np.array(mat_out)[None])
-        elif self.triton:  # NVIDIA Triton Inference Server
+
+        # NVIDIA Triton Inference Server
+        elif self.triton:
             im = im.cpu().numpy()  # torch to numpy
             y = self.model(im)
-        else:  # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+
+        # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
+        else:
             im = im.cpu().numpy()
             if self.saved_model:  # SavedModel
                 y = self.model(im, training=False) if self.keras else self.model(im)
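For reference, a hedged usage sketch of the class these hunks modify; the constructor keywords are assumed from the surrounding autobackend.py file rather than shown in this diff:

```python
import torch

from ultralytics.nn.autobackend import AutoBackend

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Constructor keywords (device, fp16, fuse, verbose) assumed from the full file, not these hunks
model = AutoBackend("yolov8n.pt", device=device, fp16=False, fuse=True, verbose=False)
im = torch.zeros(1, 3, 640, 640, device=device)  # dummy BCHW input
y = model(im)  # forward() dispatches to whichever backend branch __init__ selected
```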