Refactor TFLite example. Support FP32, Fp16, INT8 models (#17317)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/17311/head
Mohammed Yasin 2 weeks ago committed by GitHub
parent 788387831a
commit d28caa9a58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      examples/README.md
  2. 65
      examples/YOLOv8-OpenCV-int8-tflite-Python/README.md
  3. 308
      examples/YOLOv8-OpenCV-int8-tflite-Python/main.py
  4. 55
      examples/YOLOv8-TFLite-Python/README.md
  5. 221
      examples/YOLOv8-TFLite-Python/main.py

@ -18,7 +18,7 @@ This directory features a collection of real-world applications and walkthroughs
| [YOLOv8 Region Counter](https://github.com/RizwanMunawar/ultralytics/blob/main/examples/YOLOv8-Region-Counter/yolov8_region_counter.py) | Python | [Muhammad Rizwan Munawar](https://github.com/RizwanMunawar) |
| [YOLOv8 Segmentation ONNXRuntime Python](./YOLOv8-Segmentation-ONNXRuntime-Python) | Python/ONNXRuntime | [jamjamjon](https://github.com/jamjamjon) |
| [YOLOv8 LibTorch CPP](./YOLOv8-LibTorch-CPP-Inference) | C++/LibTorch | [Myyura](https://github.com/Myyura) |
| [YOLOv8 OpenCV INT8 TFLite Python](./YOLOv8-OpenCV-int8-tflite-Python) | Python | [Wamiq Raza](https://github.com/wamiqraza) |
| [YOLOv8 OpenCV INT8 TFLite Python](./YOLOv8-TFLite-Python) | Python | [Wamiq Raza](https://github.com/wamiqraza) |
| [YOLOv8 All Tasks ONNXRuntime Rust](./YOLOv8-ONNXRuntime-Rust) | Rust/ONNXRuntime | [jamjamjon](https://github.com/jamjamjon) |
| [YOLOv8 OpenVINO CPP](./YOLOv8-OpenVINO-CPP-Inference) | C++/OpenVINO | [Erlangga Yudi Pradana](https://github.com/rlggyp) |

@ -1,65 +0,0 @@
# YOLOv8 - Int8-TFLite Runtime
Welcome to the YOLOv8 Int8 TFLite Runtime for efficient and optimized object detection project. This README provides comprehensive instructions for installing and using our YOLOv8 implementation.
## Installation
Ensure a smooth setup by following these steps to install necessary dependencies.
### Installing Required Dependencies
Install all required dependencies with this simple command:
```bash
pip install -r requirements.txt
```
### Installing `tflite-runtime`
To load TFLite models, install the `tflite-runtime` package using:
```bash
pip install tflite-runtime
```
### Installing `tensorflow-gpu` (For NVIDIA GPU Users)
Leverage GPU acceleration with NVIDIA GPUs by installing `tensorflow-gpu`:
```bash
pip install tensorflow-gpu
```
**Note:** Ensure you have compatible GPU drivers installed on your system.
### Installing `tensorflow` (CPU Version)
For CPU usage or non-NVIDIA GPUs, install TensorFlow with:
```bash
pip install tensorflow
```
## Usage
Follow these instructions to run YOLOv8 after successful installation.
Convert the YOLOv8 model to Int8 TFLite format:
```bash
yolo export model=yolov8n.pt imgsz=640 format=tflite int8
```
Locate the Int8 TFLite model in `yolov8n_saved_model`. Choose `best_full_integer_quant` or verify quantization at [Netron](https://netron.app/). Then, execute the following in your terminal:
```bash
python main.py --model yolov8n_full_integer_quant.tflite --img image.jpg --conf-thres 0.5 --iou-thres 0.5
```
Replace `best_full_integer_quant.tflite` with your model file's path, `image.jpg` with your input image, and adjust the confidence (conf-thres) and IoU thresholds (iou-thres) as necessary.
### Output
The output is displayed as annotated images, showcasing the model's detection capabilities:
![image](https://github.com/wamiqraza/Attribute-recognition-and-reidentification-Market1501-dataset/blob/main/img/bus.jpg)

@ -1,308 +0,0 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import argparse
import cv2
import numpy as np
from tflite_runtime import interpreter as tflite
from ultralytics.utils import ASSETS, yaml_load
from ultralytics.utils.checks import check_yaml
# Declare as global variables, can be updated based trained model image size
img_width = 640
img_height = 640
class LetterBox:
"""Resizes and reshapes images while maintaining aspect ratio by adding padding, suitable for YOLO models."""
def __init__(
self, new_shape=(img_width, img_height), auto=False, scaleFill=False, scaleup=True, center=True, stride=32
):
"""Initializes LetterBox with parameters for reshaping and transforming image while maintaining aspect ratio."""
self.new_shape = new_shape
self.auto = auto
self.scaleFill = scaleFill
self.scaleup = scaleup
self.stride = stride
self.center = center # Put the image in the middle or top-left
def __call__(self, labels=None, image=None):
"""Return updated labels and image with added border."""
if labels is None:
labels = {}
img = labels.get("img") if image is None else image
shape = img.shape[:2] # current shape [height, width]
new_shape = labels.pop("rect_shape", self.new_shape)
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
if not self.scaleup: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)
# Compute padding
ratio = r, r # width, height ratios
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if self.auto: # minimum rectangle
dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride) # wh padding
elif self.scaleFill: # stretch
dw, dh = 0.0, 0.0
new_unpad = (new_shape[1], new_shape[0])
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
if self.center:
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)) if self.center else 0, int(round(dh + 0.1))
left, right = int(round(dw - 0.1)) if self.center else 0, int(round(dw + 0.1))
img = cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)
) # add border
if labels.get("ratio_pad"):
labels["ratio_pad"] = (labels["ratio_pad"], (left, top)) # for evaluation
if len(labels):
labels = self._update_labels(labels, ratio, dw, dh)
labels["img"] = img
labels["resized_shape"] = new_shape
return labels
else:
return img
def _update_labels(self, labels, ratio, padw, padh):
"""Update labels."""
labels["instances"].convert_bbox(format="xyxy")
labels["instances"].denormalize(*labels["img"].shape[:2][::-1])
labels["instances"].scale(*ratio)
labels["instances"].add_padding(padw, padh)
return labels
class Yolov8TFLite:
"""Class for performing object detection using YOLOv8 model converted to TensorFlow Lite format."""
def __init__(self, tflite_model, input_image, confidence_thres, iou_thres):
"""
Initializes an instance of the Yolov8TFLite class.
Args:
tflite_model: Path to the TFLite model.
input_image: Path to the input image.
confidence_thres: Confidence threshold for filtering detections.
iou_thres: IoU (Intersection over Union) threshold for non-maximum suppression.
"""
self.tflite_model = tflite_model
self.input_image = input_image
self.confidence_thres = confidence_thres
self.iou_thres = iou_thres
# Load the class names from the COCO dataset
self.classes = yaml_load(check_yaml("coco8.yaml"))["names"]
# Generate a color palette for the classes
self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))
def draw_detections(self, img, box, score, class_id):
"""
Draws bounding boxes and labels on the input image based on the detected objects.
Args:
img: The input image to draw detections on.
box: Detected bounding box.
score: Corresponding detection score.
class_id: Class ID for the detected object.
Returns:
None
"""
# Extract the coordinates of the bounding box
x1, y1, w, h = box
# Retrieve the color for the class ID
color = self.color_palette[class_id]
# Draw the bounding box on the image
cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
# Create the label text with class name and score
label = f"{self.classes[class_id]}: {score:.2f}"
# Calculate the dimensions of the label text
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
# Calculate the position of the label text
label_x = x1
label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
# Draw a filled rectangle as the background for the label text
cv2.rectangle(
img,
(int(label_x), int(label_y - label_height)),
(int(label_x + label_width), int(label_y + label_height)),
color,
cv2.FILLED,
)
# Draw the label text on the image
cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
def preprocess(self):
"""
Preprocesses the input image before performing inference.
Returns:
image_data: Preprocessed image data ready for inference.
"""
# Read the input image using OpenCV
self.img = cv2.imread(self.input_image)
print("image before", self.img)
# Get the height and width of the input image
self.img_height, self.img_width = self.img.shape[:2]
letterbox = LetterBox(new_shape=[img_width, img_height], auto=False, stride=32)
image = letterbox(image=self.img)
image = [image]
image = np.stack(image)
image = image[..., ::-1].transpose((0, 3, 1, 2))
img = np.ascontiguousarray(image)
# n, h, w, c
image = img.astype(np.float32)
return image / 255
def postprocess(self, input_image, output):
"""
Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.
Args:
input_image (numpy.ndarray): The input image.
output (numpy.ndarray): The output of the model.
Returns:
numpy.ndarray: The input image with detections drawn on it.
"""
# Transpose predictions outside the loop
output = [np.transpose(pred) for pred in output]
boxes = []
scores = []
class_ids = []
# Vectorize extraction of bounding boxes, scores, and class IDs
for pred in output:
x, y, w, h = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
x1 = x - w / 2
y1 = y - h / 2
boxes.extend(np.column_stack([x1, y1, w, h]))
# Argmax and score extraction for all predictions at once
idx = np.argmax(pred[:, 4:], axis=1)
scores.extend(pred[np.arange(pred.shape[0]), idx + 4])
class_ids.extend(idx)
# Precompute gain and pad once
img_height, img_width = input_image.shape[:2]
gain = min(img_width / self.img_width, img_height / self.img_height)
pad = (
round((img_width - self.img_width * gain) / 2 - 0.1),
round((img_height - self.img_height * gain) / 2 - 0.1),
)
# Non-Maximum Suppression (NMS) in one go
indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thres, self.iou_thres)
# Process selected indices
for i in indices.flatten():
box = boxes[i]
box[0] = (box[0] - pad[0]) / gain
box[1] = (box[1] - pad[1]) / gain
box[2] = box[2] / gain
box[3] = box[3] / gain
score = scores[i]
class_id = class_ids[i]
if score > 0.25:
# Draw the detection on the input image
self.draw_detections(input_image, box, score, class_id)
return input_image
def main(self):
"""
Performs inference using a TFLite model and returns the output image with drawn detections.
Returns:
output_img: The output image with drawn detections.
"""
# Create an interpreter for the TFLite model
interpreter = tflite.Interpreter(model_path=self.tflite_model)
self.model = interpreter
interpreter.allocate_tensors()
# Get the model inputs
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
# Store the shape of the input for later use
input_shape = input_details[0]["shape"]
self.input_width = input_shape[1]
self.input_height = input_shape[2]
# Preprocess the image data
img_data = self.preprocess()
img_data = img_data
# img_data = img_data.cpu().numpy()
# Set the input tensor to the interpreter
print(input_details[0]["index"])
print(img_data.shape)
img_data = img_data.transpose((0, 2, 3, 1))
scale, zero_point = input_details[0]["quantization"]
img_data_int8 = (img_data / scale + zero_point).astype(np.int8)
interpreter.set_tensor(input_details[0]["index"], img_data_int8)
# Run inference
interpreter.invoke()
# Get the output tensor from the interpreter
output = interpreter.get_tensor(output_details[0]["index"])
scale, zero_point = output_details[0]["quantization"]
output = (output.astype(np.float32) - zero_point) * scale
output[:, [0, 2]] *= img_width
output[:, [1, 3]] *= img_height
print(output)
# Perform post-processing on the outputs to obtain output image.
return self.postprocess(self.img, output)
if __name__ == "__main__":
# Create an argument parser to handle command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", type=str, default="yolov8n_full_integer_quant.tflite", help="Input your TFLite model."
)
parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image.")
parser.add_argument("--conf-thres", type=float, default=0.5, help="Confidence threshold")
parser.add_argument("--iou-thres", type=float, default=0.5, help="NMS IoU threshold")
args = parser.parse_args()
# Create an instance of the Yolov8TFLite class with the specified arguments
detection = Yolov8TFLite(args.model, args.img, args.conf_thres, args.iou_thres)
# Perform object detection and obtain the output image
output_image = detection.main()
# Display the output image in a window
cv2.imshow("Output", output_image)
# Wait for a key press to exit
cv2.waitKey(0)

@ -0,0 +1,55 @@
# YOLOv8 - TFLite Runtime
This example shows how to run inference with YOLOv8 TFLite model. It supports FP32, FP16 and INT8 models.
## Installation
### Installing `tflite-runtime`
To load TFLite models, install the `tflite-runtime` package using:
```bash
pip install tflite-runtime
```
### Installing `tensorflow-gpu` (For NVIDIA GPU Users)
Leverage GPU acceleration with NVIDIA GPUs by installing `tensorflow-gpu`:
```bash
pip install tensorflow-gpu
```
**Note:** Ensure you have compatible GPU drivers installed on your system.
### Installing `tensorflow` (CPU Version)
For CPU usage or non-NVIDIA GPUs, install TensorFlow with:
```bash
pip install tensorflow
```
## Usage
Follow these instructions to run YOLOv8 after successful installation.
Convert the YOLOv8 model to TFLite format:
```bash
yolo export model=yolov8n.pt imgsz=640 format=tflite int8
```
Locate the TFLite model in `yolov8n_saved_model`. Then, execute the following in your terminal:
```bash
python main.py --model yolov8n_full_integer_quant.tflite --img image.jpg --conf 0.25 --iou 0.45 --metadata "metadata.yaml"
```
Replace `best_full_integer_quant.tflite` with the TFLite model path, `image.jpg` with the input image path, `metadata.yaml` with the one generated by `ultralytics` during export, and adjust the confidence (conf) and IoU thresholds (iou) as necessary.
### Output
The output would show the detections along with the class labels and confidences of each detected object.
![image](https://github.com/wamiqraza/Attribute-recognition-and-reidentification-Market1501-dataset/blob/main/img/bus.jpg)

@ -0,0 +1,221 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import argparse
from typing import Tuple, Union
import cv2
import numpy as np
import tensorflow as tf
import yaml
from ultralytics.utils import ASSETS
try:
from tflite_runtime.interpreter import Interpreter
except ImportError:
import tensorflow as tf
Interpreter = tf.lite.Interpreter
class YOLOv8TFLite:
"""
YOLOv8TFLite.
A class for performing object detection using the YOLOv8 model with TensorFlow Lite.
Attributes:
model (str): Path to the TensorFlow Lite model file.
conf (float): Confidence threshold for filtering detections.
iou (float): Intersection over Union threshold for non-maximum suppression.
metadata (Optional[str]): Path to the metadata file, if any.
Methods:
detect(img_path: str) -> np.ndarray:
Performs inference and returns the output image with drawn detections.
"""
def __init__(self, model: str, conf: float = 0.25, iou: float = 0.45, metadata: Union[str, None] = None):
"""
Initializes an instance of the YOLOv8TFLite class.
Args:
model (str): Path to the TFLite model.
conf (float, optional): Confidence threshold for filtering detections. Defaults to 0.25.
iou (float, optional): IoU (Intersection over Union) threshold for non-maximum suppression. Defaults to 0.45.
metadata (Union[str, None], optional): Path to the metadata file or None if not used. Defaults to None.
"""
self.conf = conf
self.iou = iou
if metadata is None:
self.classes = {i: i for i in range(1000)}
else:
with open(metadata) as f:
self.classes = yaml.safe_load(f)["names"]
np.random.seed(42)
self.color_palette = np.random.uniform(128, 255, size=(len(self.classes), 3))
self.model = Interpreter(model_path=model)
self.model.allocate_tensors()
input_details = self.model.get_input_details()[0]
self.in_width, self.in_height = input_details["shape"][1:3]
self.in_index = input_details["index"]
self.in_scale, self.in_zero_point = input_details["quantization"]
self.int8 = input_details["dtype"] == np.int8
output_details = self.model.get_output_details()[0]
self.out_index = output_details["index"]
self.out_scale, self.out_zero_point = output_details["quantization"]
def letterbox(self, img: np.ndarray, new_shape: Tuple = (640, 640)) -> Tuple[np.ndarray, Tuple[float, float]]:
"""Resizes and reshapes images while maintaining aspect ratio by adding padding, suitable for YOLO models."""
shape = img.shape[:2] # current shape [height, width]
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = (new_shape[1] - new_unpad[0]) / 2, (new_shape[0] - new_unpad[1]) / 2 # wh padding
if shape[::-1] != new_unpad: # resize
img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))
return img, (top / img.shape[0], left / img.shape[1])
def draw_detections(self, img: np.ndarray, box: np.ndarray, score: np.float32, class_id: int) -> None:
"""
Draws bounding boxes and labels on the input image based on the detected objects.
Args:
img (np.ndarray): The input image to draw detections on.
box (np.ndarray): Detected bounding box in the format [x1, y1, width, height].
score (np.float32): Corresponding detection score.
class_id (int): Class ID for the detected object.
Returns:
None
"""
x1, y1, w, h = box
color = self.color_palette[class_id]
cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
label = f"{self.classes[class_id]}: {score:.2f}"
(label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
label_x = x1
label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
cv2.rectangle(
img,
(int(label_x), int(label_y - label_height)),
(int(label_x + label_width), int(label_y + label_height)),
color,
cv2.FILLED,
)
cv2.putText(img, label, (int(label_x), int(label_y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
def preprocess(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[float, float]]:
"""
Preprocesses the input image before performing inference.
Args:
img (np.ndarray): The input image to be preprocessed.
Returns:
Tuple[np.ndarray, Tuple[float, float]]: A tuple containing:
- The preprocessed image (np.ndarray).
- A tuple of two float values representing the padding applied (top/bottom, left/right).
"""
img, pad = self.letterbox(img, (self.in_width, self.in_height))
img = img[..., ::-1][None] # N,H,W,C for TFLite
img = np.ascontiguousarray(img)
img = img.astype(np.float32)
return img / 255, pad
def postprocess(self, img: np.ndarray, outputs: np.ndarray, pad: Tuple[float, float]) -> np.ndarray:
"""
Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.
Args:
img (numpy.ndarray): The input image.
outputs (numpy.ndarray): The output of the model.
pad (Tuple[float, float]): Padding used by letterbox.
Returns:
numpy.ndarray: The input image with detections drawn on it.
"""
outputs[:, 0] -= pad[1]
outputs[:, 1] -= pad[0]
outputs[:, :4] *= max(img.shape)
outputs = outputs.transpose(0, 2, 1)
outputs[..., 0] -= outputs[..., 2] / 2
outputs[..., 1] -= outputs[..., 3] / 2
for out in outputs:
scores = out[:, 4:].max(-1)
keep = scores > self.conf
boxes = out[keep, :4]
scores = scores[keep]
class_ids = out[keep, 4:].argmax(-1)
indices = cv2.dnn.NMSBoxes(boxes, scores, self.conf, self.iou).flatten()
[self.draw_detections(img, boxes[i], scores[i], class_ids[i]) for i in indices]
return img
def detect(self, img_path: str) -> np.ndarray:
"""
Performs inference using a TFLite model and returns the output image with drawn detections.
Args:
img_path (str): The path to the input image file.
Returns:
np.ndarray: The output image with drawn detections.
"""
img = cv2.imread(img_path)
x, pad = self.preprocess(img)
if self.int8:
x = (x / self.in_scale + self.in_zero_point).astype(np.int8)
self.model.set_tensor(self.in_index, x)
self.model.invoke()
y = self.model.get_tensor(self.out_index)
if self.int8:
y = (y.astype(np.float32) - self.out_zero_point) * self.out_scale
return self.postprocess(img, y, pad)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="yolov8n_saved_model/yolov8n_full_integer_quant.tflite",
help="Path to TFLite model.",
)
parser.add_argument("--img", type=str, default=str(ASSETS / "bus.jpg"), help="Path to input image")
parser.add_argument("--conf", type=float, default=0.25, help="Confidence threshold")
parser.add_argument("--iou", type=float, default=0.45, help="NMS IoU threshold")
parser.add_argument("--metadata", type=str, default="yolov8n_saved_model/metadata.yaml", help="Metadata yaml")
args = parser.parse_args()
detector = YOLOv8TFLite(args.model, args.conf, args.iou, args.metadata)
result = detector.detect(str(ASSETS / "bus.jpg"))[..., ::-1]
cv2.imshow("Output", result)
cv2.waitKey(0)
Loading…
Cancel
Save