`ultralytics 8.0.195` NVIDIA Triton Inference Server support (#5257)

Co-authored-by: TheConstant3 <46416203+TheConstant3@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/5142/head^2 v8.0.195
Glenn Jocher 1 year ago committed by GitHub
parent 40e3923cfc
commit c7aa83da31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      .github/workflows/ci.yaml
  2. 2
      docker/Dockerfile
  3. 2
      docs/guides/azureml-quickstart.md
  4. 1
      docs/guides/index.md
  5. 76
      docs/guides/raspberry-pi.md
  6. 137
      docs/guides/triton-inference-server.md
  7. 2
      docs/modes/export.md
  8. 9
      docs/reference/utils/triton.md
  9. 2
      docs/tasks/classify.md
  10. 2
      docs/tasks/detect.md
  11. 2
      docs/tasks/pose.md
  12. 2
      docs/tasks/segment.md
  13. 2
      mkdocs.yml
  14. 71
      tests/test_python.py
  15. 2
      ultralytics/__init__.py
  16. 13
      ultralytics/engine/model.py
  17. 15
      ultralytics/models/fastsam/predict.py
  18. 20
      ultralytics/nn/autobackend.py
  19. 2
      ultralytics/utils/__init__.py
  20. 1
      ultralytics/utils/metrics.py
  21. 86
      ultralytics/utils/triton.py

@ -241,7 +241,7 @@ jobs:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
Conda:
if: github.repository == 'ultralytics/ultralytics' && (github.event_name == 'schedule-disabled' || github.event.inputs.conda == 'true')
if: github.repository == 'ultralytics/ultralytics' && (github.event_name == 'schedule' || github.event.inputs.conda == 'true')
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false

@ -3,7 +3,7 @@
# Image is CUDA-optimized for YOLOv8 single/multi-GPU training and inference
# Start FROM PyTorch image https://hub.docker.com/r/pytorch/pytorch or nvcr.io/nvidia/pytorch:23.03-py3
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
RUN pip install --no-cache nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com
# Downloads to user config dir

@ -77,7 +77,7 @@ Train a detection model for 10 epochs with an initial learning_rate of 0.01:
yolo train data=coco128.yaml model=yolov8n.pt epochs=10 lr0=0.01
```
You can find more [instructions to use the Ultralytics cli here](https://docs.ultralytics.com/quickstart/#use-ultralytics-with-cli).
You can find more [instructions to use the Ultralytics CLI here](https://docs.ultralytics.com/quickstart/#use-ultralytics-with-cli).
## Quickstart from a Notebook

@ -22,6 +22,7 @@ Here's a compilation of in-depth guides to help you master different aspects of
* [Conda Quickstart](conda-quickstart.md) 🚀 NEW: Step-by-step guide to setting up a [Conda](https://anaconda.org/conda-forge/ultralytics) environment for Ultralytics. Learn how to install and start using the Ultralytics package efficiently with Conda.
* [Docker Quickstart](docker-quickstart.md) 🚀 NEW: Complete guide to setting up and using Ultralytics YOLO models with [Docker](https://hub.docker.com/r/ultralytics/ultralytics). Learn how to install Docker, manage GPU support, and run YOLO models in isolated containers for consistent development and deployment.
* [Raspberry Pi](raspberry-pi.md) 🚀 NEW: Quickstart tutorial to run YOLO models to the latest Raspberry Pi hardware.
* [Triton Inference Server Integration](triton-inference-server.md) 🚀 NEW: Dive into the integration of Ultralytics YOLOv8 with NVIDIA's Triton Inference Server for scalable and efficient deep learning inference deployments.
## Contribute to Our Guides

@ -37,47 +37,25 @@ You should see a video feed from your camera.
This guide offers you the flexibility to start with either [YOLOv5](https://github.com/ultralytics/yolov5) or [YOLOv8](https://github.com/ultralytics/ultralytics). Both versions have their unique advantages and use-cases. The choice is yours, but remember, the guide's aim is not just quick setup but also a robust foundation for your future work in object detection.
## Hardware Specifics: Raspberry Pi 3 vs Raspberry Pi 4
Raspberry Pi 3 and Raspberry Pi 4 have distinct hardware specifications, and the YOLO installation and configuration process can vary slightly depending on which model you're using.
### Raspberry Pi 3
- **CPU**: 1.2GHz Quad-Core ARM Cortex-A53
- **RAM**: 1GB LPDDR2
- **USB Ports**: 4 x USB 2.0
- **Network**: Ethernet & Wi-Fi 802.11n
- **Performance**: Generally slower, may require lighter YOLO models for real-time processing
- **Power Requirement**: 2.5A power supply
- **Official Documentation**: [Raspberry Pi 3 Documentation](https://www.raspberrypi.org/documentation/hardware/raspberrypi/bcm2837/README.md)
### Raspberry Pi 4
- **CPU**: 1.5GHz Quad-core 64-bit ARM Cortex-A72 CPU
- **RAM**: Options of 2GB, 4GB or 8GB LPDDR4
- **USB Ports**: 2 x USB 2.0, 2 x USB 3.0
- **Network**: Gigabit Ethernet & Wi-Fi 802.11ac
- **Performance**: Faster, capable of running more complex YOLO models in real-time
- **Power Requirement**: 3.0A USB-C power supply
- **Official Documentation**: [Raspberry Pi 4 Documentation](https://www.raspberrypi.org/documentation/hardware/raspberrypi/bcm2711/README.md)
### Raspberry Pi 5
- **CPU**: 2.4GHz Quad-core 64-bit Arm Cortex-A76 CPU
- **GPU**: VideoCore VII, supporting OpenGL ES 3.1, Vulkan 1.2
- **Display Output**: Dual 4Kp60 HDMI
- **Decoder**: 4Kp60 HEVC
- **Network**: Gigabit Ethernet with PoE+ support, Dual-band 802.11ac Wi-Fi®, Bluetooth 5.0 / BLE
- **USB Ports**: 2 x USB 3.0, 2 x USB 2.0
- **Other Features**: High-speed microSD card interface with SDR104 mode, 2 × 4-lane MIPI camera/display transceivers, PCIe 2.0 x1 interface, standard 40-pin GPIO header, real-time clock, power button
- **Power Requirement**: Specifics not yet available, expected to require a higher amperage supply
- **Official Documentation**: [Raspberry Pi 5 Documentation](https://www.raspberrypi.com/news/introducing-raspberry-pi-5/)
## Hardware Specifics: At a Glance
To assist you in making an informed hardware decision, we've summarized the key hardware specifics of Raspberry Pi 3, 4, and 5 in the table below:
| Feature | Raspberry Pi 3 | Raspberry Pi 4 | Raspberry Pi 5 |
|----------------------------|------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------|----------------------------------------------------------------------|
| **CPU** | 1.2GHz Quad-Core ARM Cortex-A53 | 1.5GHz Quad-core 64-bit ARM Cortex-A72 | 2.4GHz Quad-core 64-bit Arm Cortex-A76 |
| **RAM** | 1GB LPDDR2 | 2GB, 4GB or 8GB LPDDR4 | *Details not yet available* |
| **USB Ports** | 4 x USB 2.0 | 2 x USB 2.0, 2 x USB 3.0 | 2 x USB 3.0, 2 x USB 2.0 |
| **Network** | Ethernet & Wi-Fi 802.11n | Gigabit Ethernet & Wi-Fi 802.11ac | Gigabit Ethernet with PoE+ support, Dual-band 802.11ac Wi-Fi® |
| **Performance** | Slower, may require lighter YOLO models | Faster, can run complex YOLO models | *Details not yet available* |
| **Power Requirement** | 2.5A power supply | 3.0A USB-C power supply | *Details not yet available* |
| **Official Documentation** | [Link](https://www.raspberrypi.org/documentation/hardware/raspberrypi/bcm2837/README.md) | [Link](https://www.raspberrypi.org/documentation/hardware/raspberrypi/bcm2711/README.md) | [Link](https://www.raspberrypi.com/news/introducing-raspberry-pi-5/) |
Please make sure to follow the instructions specific to your Raspberry Pi model to ensure a smooth setup process.
## Quick Start with YOLOv5
This section outlines how to set up YOLOv5 on a Raspberry Pi 3 or 4 with a Pi Camera. These steps are designed to be compatible with the libcamera camera stack introduced in Raspberry Pi OS Bullseye.
This section outlines how to set up YOLOv5 on a Raspberry Pi with a Pi Camera. These steps are designed to be compatible with the libcamera camera stack introduced in Raspberry Pi OS Bullseye.
### Install Necessary Packages
@ -171,7 +149,7 @@ Follow this section if you are interested in setting up YOLOv8 instead. The step
sudo apt-get autoremove -y
```
2. Install YOLOv8:
2. Install the `ultralytics` Python package:
```bash
pip3 install ultralytics
@ -183,28 +161,6 @@ Follow this section if you are interested in setting up YOLOv8 instead. The step
sudo reboot
```
### Modify `build.py`
Just like YOLOv5, YOLOv8 also needs minor modifications to accept TCP streams.
1. Open `build.py` located in the Ultralytics package folder:
```bash
sudo nano /home/pi/.local/lib/pythonX.X/site-packages/ultralytics/build.py
```
2. Find and modify the `is_url` line to accept TCP streams:
```python
is_url = source.lower().startswith(('rtsp://', 'rtmp://', 'http://', 'https://', 'tcp://'))
```
3. Save and exit:
```bash
CTRL + O -> ENTER -> CTRL + X
```
### Initiate TCP Stream with Libcamera
1. Start the TCP stream:
@ -231,7 +187,7 @@ while True:
## Next Steps
Congratulations on successfully setting up YOLO on your Raspberry Pi! For further learning and support, visit [Ultralytics](https://ultralytics.com/) and [KashmirWorldFoundation](https://www.kashmirworldfoundation.org/).
Congratulations on successfully setting up YOLO on your Raspberry Pi! For further learning and support, visit [Ultralytics](https://ultralytics.com/) and [Kashmir World Foundation](https://www.kashmirworldfoundation.org/).
## Acknowledgements and Citations

@ -0,0 +1,137 @@
---
comments: true
description: A step-by-step guide on integrating Ultralytics YOLOv8 with Triton Inference Server for scalable and high-performance deep learning inference deployments.
keywords: YOLOv8, Triton Inference Server, ONNX, Deep Learning Deployment, Scalable Inference, Ultralytics, NVIDIA, Object Detection, Cloud Inferencing
---
# Triton Inference Server with Ultralytics YOLOv8
The [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server) (formerly known as TensorRT Inference Server) is an open-source software solution developed by NVIDIA. It provides a cloud inferencing solution optimized for NVIDIA GPUs. Triton simplifies the deployment of AI models at scale in production. Integrating Ultralytics YOLOv8 with Triton Inference Server allows you to deploy scalable, high-performance deep learning inference workloads. This guide provides steps to set up and test the integration.
<p align="center">
<br>
<iframe width="720" height="405" src="https://www.youtube.com/embed/NQDtfSi5QF4"
title="Getting Started with NVIDIA Triton Inference Server" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
allowfullscreen>
</iframe>
<br>
<strong>Watch:</strong> Getting Started with NVIDIA Triton Inference Server.
</p>
## What is Triton Inference Server?
Triton Inference Server is designed to deploy a variety of AI models in production. It supports a wide range of deep learning and machine learning frameworks, including TensorFlow, PyTorch, ONNX Runtime, and many others. Its primary use cases are:
- Serving multiple models from a single server instance.
- Dynamic model loading and unloading without server restart.
- Ensemble inferencing, allowing multiple models to be used together to achieve results.
- Model versioning for A/B testing and rolling updates.
## Prerequisites
Ensure you have the following prerequisites before proceeding:
- Docker installed on your machine.
- Install `tritonclient`:
```bash
pip install tritonclient[all]
```
## Exporting YOLOv8 to ONNX Format
Before deploying the model on Triton, it must be exported to the ONNX format. ONNX (Open Neural Network Exchange) is a format that allows models to be transferred between different deep learning frameworks. Use the `export` function from the `YOLO` class:
```python
from ultralytics import YOLO
# Load a model
model = YOLO('yolov8n.pt') # load an official model
# Export the model
onnx_file = model.export(format='onnx', dynamic=True)
```
## Setting Up Triton Model Repository
The Triton Model Repository is a storage location where Triton can access and load models.
1. Create the necessary directory structure:
```python
from pathlib import Path
# Define paths
triton_repo_path = Path('tmp') / 'triton_repo'
triton_model_path = triton_repo_path / 'yolo'
# Create directories
(triton_model_path / '1').mkdir(parents=True, exist_ok=True)
```
2. Move the exported ONNX model to the Triton repository:
```python
from pathlib import Path
# Move ONNX model to Triton Model path
Path(onnx_file).rename(triton_model_path / '1' / 'model.onnx')
# Create config file
(triton_model_path / 'config.pdtxt').touch()
```
## Running Triton Inference Server
Run the Triton Inference Server using Docker:
```python
import subprocess
import time
from tritonclient.http import InferenceServerClient
# Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
tag = 'nvcr.io/nvidia/tritonserver:23.09-py3' # 6.4 GB
# Pull the image
subprocess.call(f'docker pull {tag}', shell=True)
# Run the Triton server and capture the container ID
container_id = subprocess.check_output(
f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models',
shell=True).decode('utf-8').strip()
# Wait for the Triton server to start
triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)
# Wait until model is ready
for _ in range(10):
with contextlib.suppress(Exception):
assert triton_client.is_model_ready(model_name)
break
time.sleep(1)
```
Then run inference using the Triton Server model:
```python
from ultralytics import YOLO
# Load the Triton Server model
model = YOLO(f'http://localhost:8000/yolo', task='detect')
# Run inference on the server
results = model('path/to/image.jpg')
```
Cleanup the container:
```python
# Kill and remove the container at the end of the test
subprocess.call(f'docker kill {container_id}', shell=True)
```
---
By following the above steps, you can deploy and run Ultralytics YOLOv8 models efficiently on Triton Inference Server, providing a scalable and high-performance solution for deep learning inference tasks. If you face any issues or have further queries, refer to the [official Triton documentation](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html) or reach out to the Ultralytics community for support.

@ -57,7 +57,7 @@ Export a YOLOv8n model to a different format like ONNX or TensorRT. See Argument
# Load a model
model = YOLO('yolov8n.pt') # load an official model
model = YOLO('path/to/best.pt') # load a custom trained
model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')

@ -0,0 +1,9 @@
# Reference for `ultralytics/utils/triton.py`
!!! note
Full source code for this file is available at [https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/triton.py](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/utils/triton.py). Help us fix any issues you see by submitting a [Pull Request](https://docs.ultralytics.com/help/contributing/) 🛠. Thank you 🙏!
---
## ::: ultralytics.utils.triton.TritonRemoteModel
<br><br>

@ -140,7 +140,7 @@ Export a YOLOv8n-cls model to a different format like ONNX, CoreML, etc.
# Load a model
model = YOLO('yolov8n-cls.pt') # load an official model
model = YOLO('path/to/best.pt') # load a custom trained
model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')

@ -152,7 +152,7 @@ Export a YOLOv8n model to a different format like ONNX, CoreML, etc.
# Load a model
model = YOLO('yolov8n.pt') # load an official model
model = YOLO('path/to/best.pt') # load a custom trained
model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')

@ -156,7 +156,7 @@ Export a YOLOv8n Pose model to a different format like ONNX, CoreML, etc.
# Load a model
model = YOLO('yolov8n-pose.pt') # load an official model
model = YOLO('path/to/best.pt') # load a custom trained
model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')

@ -157,7 +157,7 @@ Export a YOLOv8n-seg model to a different format like ONNX, CoreML, etc.
# Load a model
model = YOLO('yolov8n-seg.pt') # load an official model
model = YOLO('path/to/best.pt') # load a custom trained
model = YOLO('path/to/best.pt') # load a custom trained model
# Export the model
model.export(format='onnx')

@ -223,6 +223,7 @@ nav:
- Conda Quickstart: guides/conda-quickstart.md
- Docker Quickstart: guides/docker-quickstart.md
- Raspberry Pi: guides/raspberry-pi.md
- Triton Inference Server: guides/triton-inference-server.md
- Integrations:
- integrations/index.md
- OpenVINO: integrations/openvino.md
@ -390,6 +391,7 @@ nav:
- plotting: reference/utils/plotting.md
- tal: reference/utils/tal.md
- torch_utils: reference/utils/torch_utils.md
- triton: reference/utils/triton.md
- tuner: reference/utils/tuner.md
- Help:

@ -15,7 +15,7 @@ from ultralytics import RTDETR, YOLO
from ultralytics.cfg import TASK2DATA
from ultralytics.data.build import load_inference_source
from ultralytics.utils import (ASSETS, DEFAULT_CFG, DEFAULT_CFG_PATH, LINUX, MACOS, ONLINE, ROOT, WEIGHTS_DIR, WINDOWS,
is_dir_writeable)
checks, is_dir_writeable)
from ultralytics.utils.downloads import download
from ultralytics.utils.torch_utils import TORCH_1_9
@ -343,17 +343,14 @@ def test_utils_init():
def test_utils_checks():
from ultralytics.utils.checks import (check_imgsz, check_imshow, check_requirements, check_version,
check_yolov5u_filename, git_describe, print_args)
check_yolov5u_filename('yolov5n.pt')
# check_imshow(warn=True)
git_describe(ROOT)
check_requirements() # check requirements.txt
check_imgsz([600, 600], max_dim=1)
check_imshow()
check_version('ultralytics', '8.0.0')
print_args()
checks.check_yolov5u_filename('yolov5n.pt')
checks.git_describe(ROOT)
checks.check_requirements() # check requirements.txt
checks.check_imgsz([600, 600], max_dim=1)
checks.check_imshow()
checks.check_version('ultralytics', '8.0.0')
checks.print_args()
# checks.check_imshow(warn=True)
def test_utils_benchmarks():
@ -451,3 +448,53 @@ def test_hub():
export_fmts_hub()
logout()
smart_request('GET', 'http://github.com', progress=True)
@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason='environment is offline')
def test_triton():
checks.check_requirements('tritonclient[all]')
import subprocess
import time
from tritonclient.http import InferenceServerClient # noqa
# Create variables
model_name = 'yolo'
triton_repo_path = TMP / 'triton_repo'
triton_model_path = triton_repo_path / model_name
# Export model to ONNX
f = YOLO(MODEL).export(format='onnx', dynamic=True)
# Prepare Triton repo
(triton_model_path / '1').mkdir(parents=True, exist_ok=True)
Path(f).rename(triton_model_path / '1' / 'model.onnx')
(triton_model_path / 'config.pdtxt').touch()
# Define image https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver
tag = 'nvcr.io/nvidia/tritonserver:23.09-py3' # 6.4 GB
# Pull the image
subprocess.call(f'docker pull {tag}', shell=True)
# Run the Triton server and capture the container ID
container_id = subprocess.check_output(
f'docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models',
shell=True).decode('utf-8').strip()
# Wait for the Triton server to start
triton_client = InferenceServerClient(url='localhost:8000', verbose=False, ssl=False)
# Wait until model is ready
for _ in range(10):
with contextlib.suppress(Exception):
assert triton_client.is_model_ready(model_name)
break
time.sleep(1)
# Check Triton inference
YOLO(f'http://localhost:8000/{model_name}', 'detect')(SOURCE) # exported model inference
# Kill and remove the container at the end of the test
subprocess.call(f'docker kill {container_id}', shell=True)

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.194'
__version__ = '8.0.195'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM

@ -81,6 +81,12 @@ class Model(nn.Module):
self.session = HUBTrainingSession(model)
model = self.session.model_file
# Check if Triton Server model
elif self.is_triton_model(model):
self.model = model
self.task = task
return
# Load or create new YOLO model
suffix = Path(model).suffix
if not suffix and Path(model).stem in GITHUB_ASSETS_STEMS:
@ -94,6 +100,13 @@ class Model(nn.Module):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, **kwargs)
@staticmethod
def is_triton_model(model):
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
from urllib.parse import urlsplit
url = urlsplit(model)
return url.netloc and url.path and url.scheme in {'http', 'grfc'}
@staticmethod
def is_hub_model(model):
"""Check if the provided model is a HUB model."""

@ -15,13 +15,14 @@ class FastSAMPredictor(DetectionPredictor):
self.args.task = 'segment'
def postprocess(self, preds, img, orig_imgs):
p = ops.non_max_suppression(preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=len(self.model.names),
classes=self.args.classes)
p = ops.non_max_suppression(
preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=1, # set to 1 class since SAM has no class predictions
classes=self.args.classes)
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)

@ -7,7 +7,6 @@ import platform
import zipfile
from collections import OrderedDict, namedtuple
from pathlib import Path
from urllib.parse import urlparse
import cv2
import numpy as np
@ -32,8 +31,8 @@ def check_class_names(names):
raise KeyError(f'{n}-class dataset requires class indices 0-{n - 1}, but you have invalid class indices '
f'{min(names.keys())}-{max(names.keys())} defined in your dataset YAML.')
if isinstance(names[0], str) and names[0].startswith('n0'): # imagenet class codes, i.e. 'n01440764'
map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: map[v] for k, v in names.items()}
names_map = yaml_load(ROOT / 'cfg/datasets/ImageNet.yaml')['map'] # human-readable names
names = {k: names_map[v] for k, v in names.items()}
return names
@ -274,13 +273,9 @@ class AutoBackend(nn.Module):
net.load_model(str(w.with_suffix('.bin')))
metadata = w.parent / 'metadata.yaml'
elif triton: # NVIDIA Triton Inference Server
"""TODO
check_requirements('tritonclient[all]')
from utils.triton import TritonRemoteModel
model = TritonRemoteModel(url=w)
nhwc = model.runtime.startswith("tensorflow")
"""
raise NotImplementedError('Triton Inference Server is not currently supported.')
from ultralytics.utils.triton import TritonRemoteModel
model = TritonRemoteModel(w)
else:
from ultralytics.engine.exporter import export_formats
raise TypeError(f"model='{w}' is not a supported model format. "
@ -395,6 +390,7 @@ class AutoBackend(nn.Module):
ex.extract(output_name, mat_out)
y.append(np.array(mat_out)[None])
elif self.triton: # NVIDIA Triton Inference Server
im = im.cpu().numpy() # torch to numpy
y = self.model(im)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
im = im.cpu().numpy()
@ -498,6 +494,8 @@ class AutoBackend(nn.Module):
if any(types):
triton = False
else:
url = urlparse(p) # if url may be Triton inference server
triton = all([any(s in url.scheme for s in ['http', 'grpc']), url.netloc])
from urllib.parse import urlsplit
url = urlsplit(p)
triton = url.netloc and url.path and url.scheme in {'http', 'grfc'}
return types + [triton]

@ -705,7 +705,7 @@ def remove_colorstr(input_string):
>>> remove_colorstr(colorstr('blue', 'bold', 'hello world'))
>>> 'hello world'
"""
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
ansi_escape = re.compile(r'\x1B(?:[@-Z\\\-_]|\[[0-9]*[ -/]*[@-~])')
return ansi_escape.sub('', input_string)

@ -2,6 +2,7 @@
"""
Model validation metrics
"""
import math
import warnings
from pathlib import Path

@ -0,0 +1,86 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
from typing import List
from urllib.parse import urlsplit
import numpy as np
class TritonRemoteModel:
"""Client for interacting with a remote Triton Inference Server model.
Attributes:
endpoint (str): The name of the model on the Triton server.
url (str): The URL of the Triton server.
triton_client: The Triton client (either HTTP or gRPC).
InferInput: The input class for the Triton client.
InferRequestedOutput: The output request class for the Triton client.
input_formats (List[str]): The data types of the model inputs.
np_input_formats (List[type]): The numpy data types of the model inputs.
input_names (List[str]): The names of the model inputs.
output_names (List[str]): The names of the model outputs.
"""
def __init__(self, url: str, endpoint: str = '', scheme: str = ''):
"""
Initialize the TritonRemoteModel.
Arguments may be provided individually or parsed from a collective 'url' argument of the form
<scheme>://<netloc>/<endpoint>/<task_name>
Args:
url (str): The URL of the Triton server.
endpoint (str): The name of the model on the Triton server.
scheme (str): The communication scheme ('http' or 'grpc').
"""
if not endpoint and not scheme: # Parse all args from URL string
splits = urlsplit(url)
endpoint = splits.path.strip('/').split('/')[0]
scheme = splits.scheme
url = splits.netloc
self.endpoint = endpoint
self.url = url
# Choose the Triton client based on the communication scheme
if scheme == 'http':
import tritonclient.http as client # noqa
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint)
else:
import tritonclient.grpc as client # noqa
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint, as_json=True)['config']
self.InferRequestedOutput = client.InferRequestedOutput
self.InferInput = client.InferInput
type_map = {'TYPE_FP32': np.float32, 'TYPE_FP16': np.float16, 'TYPE_UINT8': np.uint8}
self.input_formats = [x['data_type'] for x in config['input']]
self.np_input_formats = [type_map[x] for x in self.input_formats]
self.input_names = [x['name'] for x in config['input']]
self.output_names = [x['name'] for x in config['output']]
def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
"""
Call the model with the given inputs.
Args:
*inputs (List[np.ndarray]): Input data to the model.
Returns:
List[np.ndarray]: Model outputs.
"""
infer_inputs = []
input_format = inputs[0].dtype
for i, x in enumerate(inputs):
if x.dtype != self.np_input_formats[i]:
x = x.astype(self.np_input_formats[i])
infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace('TYPE_', ''))
infer_input.set_data_from_numpy(x)
infer_inputs.append(infer_input)
infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)
return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
Loading…
Cancel
Save