You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

429 lines
23 KiB

# Ultralytics YOLO 🚀, GPL-3.0 license
import ast
import contextlib
import json
import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
from urllib.parse import urlparse
import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from ultralytics.yolo.utils import LOGGER, ROOT, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_version, check_yaml
from ultralytics.yolo.utils.downloads import attempt_download_asset, is_url
from ultralytics.yolo.utils.ops import xywh2xyxy
def check_class_names(names):
# Check class names. Map imagenet class codes to human-readable names if required. Convert lists to dicts.
if isinstance(names, list): # names is a list
names = dict(enumerate(names)) # convert to dict
if isinstance(names, dict):
if not all(isinstance(k, int) for k in names.keys()): # convert string keys to int, i.e. '0' to 0
names = {int(k): v for k, v in names.items()}
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
return names
class AutoBackend(nn.Module):
def _apply_default_class_names(self, data):
with contextlib.suppress(Exception):
return yaml_load(check_yaml(data))['names']
return {i: f'class{i}' for i in range(999)} # return default if above errors
def __init__(self, weights='yolov8n.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True):
"""
MultiBackend class for python inference on various platforms using Ultralytics YOLO.
Args:
weights (str): The path to the weights file. Default: 'yolov8n.pt'
device (torch.device): The device to run the model on.
dnn (bool): Use OpenCV's DNN module for inference if True, defaults to False.
data (str), (Path): Additional data.yaml file for class names, optional
fp16 (bool): If True, use half precision. Default: False
fuse (bool): Whether to fuse the model or not. Default: True
Supported formats and their naming conventions:
| Format | Suffix |
|-----------------------|------------------|
| PyTorch | *.pt |
| TorchScript | *.torchscript |
| ONNX Runtime | *.onnx |
| ONNX OpenCV DNN | *.onnx dnn=True |
| OpenVINO | *.xml |
| CoreML | *.mlmodel |
| TensorRT | *.engine |
| TensorFlow SavedModel | *_saved_model |
| TensorFlow GraphDef | *.pb |
| TensorFlow Lite | *.tflite |
| TensorFlow Edge TPU | *_edgetpu.tflite |
| PaddlePaddle | *_paddle_model |
"""
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
nn_module = isinstance(weights, torch.nn.Module)
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, triton = self._model_type(w)
fp16 &= pt or jit or onnx or engine or nn_module # FP16
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
stride = 32 # default stride
model = None # TODO: resolves ONNX inference, verify effect on other backends
cuda = torch.cuda.is_available() and device.type != 'cpu' # use CUDA
if not (pt or triton or nn_module):
w = attempt_download_asset(w) # download if not local
# NOTE: special case: in-memory pytorch model
if nn_module:
model = weights.to(device)
model = model.fuse() if fuse else model
names = model.module.names if hasattr(model, 'module') else model.names # get class names
stride = max(int(model.stride.max()), 32) # model stride
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
pt = True
elif pt: # PyTorch
from ultralytics.nn.tasks import attempt_load_weights
model = attempt_load_weights(weights if isinstance(weights, list) else w,
device=device,
inplace=True,
fuse=fuse)
stride = max(int(model.stride.max()), 32) # model stride
names = model.module.names if hasattr(model, 'module') else model.names # get class names
model.half() if fp16 else model.float()
self.model = model # explicitly assign for to(), cpu(), cuda(), half()
elif jit: # TorchScript
LOGGER.info(f'Loading {w} for TorchScript inference...')
extra_files = {'config.txt': ''} # model metadata
model = torch.jit.load(w, _extra_files=extra_files, map_location=device)
model.half() if fp16 else model.float()
if extra_files['config.txt']: # load metadata dict
d = json.loads(extra_files['config.txt'],
object_hook=lambda d: {int(k) if k.isdigit() else k: v
for k, v in d.items()})
stride, names = int(d['stride']), d['names']
elif dnn: # ONNX OpenCV DNN
LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
check_requirements('opencv-python>=4.5.4')
net = cv2.dnn.readNetFromONNX(w)
elif onnx: # ONNX Runtime
LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
import onnxruntime
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
session = onnxruntime.InferenceSession(w, providers=providers)
output_names = [x.name for x in session.get_outputs()]
meta = session.get_modelmeta().custom_metadata_map # metadata
if 'stride' in meta:
stride, names = int(meta['stride']), eval(meta['names'])
elif xml: # OpenVINO
LOGGER.info(f'Loading {w} for OpenVINO inference...')
check_requirements('openvino') # requires openvino-dev: https://pypi.org/project/openvino-dev/
from openvino.runtime import Core, Layout, get_batch # noqa
ie = Core()
if not Path(w).is_file(): # if not *.xml
w = next(Path(w).glob('*.xml')) # get *.xml file from *_openvino_model dir
network = ie.read_model(model=w, weights=Path(w).with_suffix('.bin'))
if network.get_parameters()[0].get_layout().empty:
network.get_parameters()[0].set_layout(Layout('NCHW'))
batch_dim = get_batch(network)
if batch_dim.is_static:
batch_size = batch_dim.get_length()
executable_network = ie.compile_model(network, device_name='CPU') # device_name="MYRIAD" for Intel NCS2
elif engine: # TensorRT
LOGGER.info(f'Loading {w} for TensorRT inference...')
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
if device.type == 'cpu':
device = torch.device('cuda:0')
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
# Read file
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
meta_len = int.from_bytes(f.read(4), byteorder='little') # read metadata length
meta = json.loads(f.read(meta_len).decode('utf-8')) # read metadata
model = runtime.deserialize_cuda_engine(f.read()) # read engine
context = model.create_execution_context()
bindings = OrderedDict()
output_names = []
fp16 = False # default updated below
dynamic = False
for i in range(model.num_bindings):
name = model.get_binding_name(i)
dtype = trt.nptype(model.get_binding_dtype(i))
if model.binding_is_input(i):
if -1 in tuple(model.get_binding_shape(i)): # dynamic
dynamic = True
context.set_binding_shape(i, tuple(model.get_profile_shape(0, i)[2]))
if dtype == np.float16:
fp16 = True
else: # output
output_names.append(name)
shape = tuple(context.get_binding_shape(i))
im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
stride, names = int(meta['stride']), meta['names']
elif coreml: # CoreML
LOGGER.info(f'Loading {w} for CoreML inference...')
import coremltools as ct
model = ct.models.MLModel(w)
elif saved_model: # TF SavedModel
LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
import tensorflow as tf
keras = False # assume TF1 saved_model
model = tf.keras.models.load_model(w) if keras else tf.saved_model.load(w)
w = Path(w) / 'metadata.yaml'
elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
import tensorflow as tf
def wrap_frozen_graph(gd, inputs, outputs):
x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=''), []) # wrapped
ge = x.graph.as_graph_element
return x.prune(tf.nest.map_structure(ge, inputs), tf.nest.map_structure(ge, outputs))
def gd_outputs(gd):
name_list, input_list = [], []
for node in gd.node: # tensorflow.core.framework.node_def_pb2.NodeDef
name_list.append(node.name)
input_list.extend(node.input)
return sorted(f'{x}:0' for x in list(set(name_list) - set(input_list)) if not x.startswith('NoOp'))
gd = tf.Graph().as_graph_def() # TF GraphDef
with open(w, 'rb') as f:
gd.ParseFromString(f.read())
frozen_func = wrap_frozen_graph(gd, inputs='x:0', outputs=gd_outputs(gd))
elif tflite or edgetpu: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
from tflite_runtime.interpreter import Interpreter, load_delegate
except ImportError:
import tensorflow as tf
Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate
if edgetpu: # TF Edge TPU https://coral.ai/software/#edgetpu-runtime
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
delegate = {
'Linux': 'libedgetpu.so.1',
'Darwin': 'libedgetpu.1.dylib',
'Windows': 'edgetpu.dll'}[platform.system()]
interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
else: # TFLite
LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
interpreter = Interpreter(model_path=w) # load TFLite model
interpreter.allocate_tensors() # allocate
input_details = interpreter.get_input_details() # inputs
output_details = interpreter.get_output_details() # outputs
# load metadata
with contextlib.suppress(zipfile.BadZipFile):
with zipfile.ZipFile(w, 'r') as model:
meta_file = model.namelist()[0]
meta = ast.literal_eval(model.read(meta_file).decode('utf-8'))
stride, names = int(meta['stride']), meta['names']
elif tfjs: # TF.js
raise NotImplementedError('YOLOv8 TF.js inference is not supported')
elif paddle: # PaddlePaddle
LOGGER.info(f'Loading {w} for PaddlePaddle inference...')
check_requirements('paddlepaddle-gpu' if cuda else 'paddlepaddle')
import paddle.inference as pdi
if not Path(w).is_file(): # if not *.pdmodel
w = next(Path(w).rglob('*.pdmodel')) # get *.pdmodel file from *_paddle_model dir
weights = Path(w).with_suffix('.pdiparams')
config = pdi.Config(str(w), str(weights))
if cuda:
config.enable_use_gpu(memory_pool_init_size_mb=2048, device_id=0)
predictor = pdi.create_predictor(config)
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_names = predictor.get_output_names()
elif triton: # NVIDIA Triton Inference Server
LOGGER.info('Triton Inference Server not supported...')
'''
TODO:
check_requirements('tritonclient[all]')
from utils.triton import TritonRemoteModel
model = TritonRemoteModel(url=w)
nhwc = model.runtime.startswith("tensorflow")
'''
else:
from ultralytics.yolo.engine.exporter import EXPORT_FORMATS_TABLE
raise TypeError(f"model='{w}' is not a supported model format. "
'See https://docs.ultralytics.com/tasks/detection/#export for help.'
f'\n\n{EXPORT_FORMATS_TABLE}')
# Load external metadata YAML
if xml or saved_model or paddle:
metadata = Path(w).parent / 'metadata.yaml'
if metadata.exists():
metadata = yaml_load(metadata)
stride, names = int(metadata['stride']), metadata['names'] # load metadata
else:
LOGGER.warning(f"WARNING ⚠ Metadata not found at '{metadata}'")
# Check names
if 'names' not in locals(): # names missing
names = self._apply_default_class_names(data)
names = check_class_names(names)
self.__dict__.update(locals()) # assign all variables to self
def forward(self, im, augment=False, visualize=False):
"""
Runs inference on the YOLOv8 MultiBackend model.
Args:
im (torch.Tensor): The image tensor to perform inference on.
augment (bool): whether to perform data augmentation during inference, defaults to False
visualize (bool): whether to visualize the output predictions, defaults to False
Returns:
(tuple): Tuple containing the raw output tensor, and the processed output for visualization (if visualize=True)
"""
b, ch, h, w = im.shape # batch, channel, height, width
if self.fp16 and im.dtype != torch.float16:
im = im.half() # to FP16
if self.nhwc:
im = im.permute(0, 2, 3, 1) # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.pt or self.nn_module: # PyTorch
y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
elif self.jit: # TorchScript
y = self.model(im)
elif self.dnn: # ONNX OpenCV DNN
im = im.cpu().numpy() # torch to numpy
self.net.setInput(im)
y = self.net.forward()
elif self.onnx: # ONNX Runtime
im = im.cpu().numpy() # torch to numpy
y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
elif self.xml: # OpenVINO
im = im.cpu().numpy() # FP32
y = list(self.executable_network([im]).values())
elif self.engine: # TensorRT
if self.dynamic and im.shape != self.bindings['images'].shape:
i = self.model.get_binding_index('images')
self.context.set_binding_shape(i, im.shape) # reshape if dynamic
self.bindings['images'] = self.bindings['images']._replace(shape=im.shape)
for name in self.output_names:
i = self.model.get_binding_index(name)
self.bindings[name].data.resize_(tuple(self.context.get_binding_shape(i)))
s = self.bindings['images'].shape
assert im.shape == s, f"input size {im.shape} {'>' if self.dynamic else 'not equal to'} max model size {s}"
self.binding_addrs['images'] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = [self.bindings[x].data for x in sorted(self.output_names)]
elif self.coreml: # CoreML
im = im.cpu().numpy()
im = Image.fromarray((im[0] * 255).astype('uint8'))
# im = im.resize((192, 320), Image.ANTIALIAS)
y = self.model.predict({'image': im}) # coordinates are xywh normalized
if 'confidence' in y:
box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
elif len(y) == 1: # classification model
y = list(y.values())
elif len(y) == 2: # segmentation model
y = list(reversed(y.values())) # reversed for segmentation models (pred, proto)
elif self.paddle: # PaddlePaddle
im = im.cpu().numpy().astype(np.float32)
self.input_handle.copy_from_cpu(im)
self.predictor.run()
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
elif self.triton: # NVIDIA Triton Inference Server
y = self.model(im)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.cpu().numpy()
if self.saved_model: # SavedModel
y = self.model(im, training=False) if self.keras else self.model(im)
if not isinstance(y, list):
y = [y]
elif self.pb: # GraphDef
y = self.frozen_func(x=self.tf.constant(im))
if len(y) == 2 and len(self.names) == 999: # segments and names not defined
ip, ib = (0, 1) if len(y[0].shape) == 4 else (1, 0) # index of protos, boxes
nc = y[ib].shape[1] - y[ip].shape[3] - 4 # y = (1, 160, 160, 32), (1, 116, 8400)
self.names = {i: f'class{i}' for i in range(nc)}
else: # Lite or Edge TPU
input = self.input_details[0]
int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
if int8:
scale, zero_point = input['quantization']
im = (im / scale + zero_point).astype(np.uint8) # de-scale
self.interpreter.set_tensor(input['index'], im)
self.interpreter.invoke()
y = []
for output in self.output_details:
x = self.interpreter.get_tensor(output['index'])
if int8:
scale, zero_point = output['quantization']
x = (x.astype(np.float32) - zero_point) * scale # re-scale
y.append(x)
# TF segment fixes: export is reversed vs ONNX export and protos are transposed
if len(y) == 2: # segment with (det, proto) output order reversed
if len(y[1].shape) != 4:
y = list(reversed(y)) # should be y = (1, 116, 8400), (1, 160, 160, 32)
y[1] = np.transpose(y[1], (0, 3, 1, 2)) # should be y = (1, 116, 8400), (1, 32, 160, 160)
y = [x if isinstance(x, np.ndarray) else x.numpy() for x in y]
# y[0][..., :4] *= [w, h, w, h] # xywh normalized to pixels
# for x in y:
# print(type(x), len(x)) if isinstance(x, (list, tuple)) else print(type(x), x.shape) # debug shapes
if isinstance(y, (list, tuple)):
return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
else:
return self.from_numpy(y)
def from_numpy(self, x):
"""
Convert a numpy array to a tensor.
Args:
x (np.ndarray): The array to be converted.
Returns:
(torch.Tensor): The converted tensor
"""
return torch.tensor(x).to(self.device) if isinstance(x, np.ndarray) else x
def warmup(self, imgsz=(1, 3, 640, 640)):
"""
Warm up the model by running one forward pass with a dummy input.
Args:
imgsz (tuple): The shape of the dummy input tensor in the format (batch_size, channels, height, width)
Returns:
(None): This method runs the forward pass and don't return any value
"""
warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb, self.triton, self.nn_module
if any(warmup_types) and (self.device.type != 'cpu' or self.triton):
im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
for _ in range(2 if self.jit else 1): #
self.forward(im) # warmup
@staticmethod
def _model_type(p='path/to/model.pt'):
"""
This function takes a path to a model file and returns the model type
Args:
p: path to the model file. Defaults to path/to/model.pt
"""
# Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
# types = [pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle]
from ultralytics.yolo.engine.exporter import export_formats
sf = list(export_formats().Suffix) # export suffixes
if not is_url(p, check=False) and not isinstance(p, str):
check_suffix(p, sf) # checks
url = urlparse(p) # if url may be Triton inference server
types = [s in Path(p).name for s in sf]
types[8] &= not types[9] # tflite &= not edgetpu
triton = not any(types) and all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
return types + [triton]