From dc15242cbdc56887a68e78c864119f6918bca0de Mon Sep 17 00:00:00 2001 From: memorylorry Date: Sun, 25 Aug 2024 07:05:52 +0800 Subject: [PATCH 1/2] Fix YOLOv8 C++ ONNXRuntime transpose op (#15779) Co-authored-by: Glenn Jocher --- examples/YOLOv8-ONNXRuntime-CPP/inference.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp b/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp index 5154a8303..2ee993eed 100644 --- a/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp +++ b/examples/YOLOv8-ONNXRuntime-CPP/inference.cpp @@ -221,8 +221,8 @@ char* YOLO_V8::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std:: case YOLO_DETECT_V8: case YOLO_DETECT_V8_HALF: { - int strideNum = outputNodeDims[1];//8400 - int signalResultNum = outputNodeDims[2];//84 + int signalResultNum = outputNodeDims[1];//84 + int strideNum = outputNodeDims[2];//8400 std::vector class_ids; std::vector confidences; std::vector boxes; @@ -230,18 +230,18 @@ char* YOLO_V8::TensorProcess(clock_t& starttime_1, cv::Mat& iImg, N& blob, std:: if (modelType == YOLO_DETECT_V8) { // FP32 - rawData = cv::Mat(strideNum, signalResultNum, CV_32F, output); + rawData = cv::Mat(signalResultNum, strideNum, CV_32F, output); } else { // FP16 - rawData = cv::Mat(strideNum, signalResultNum, CV_16F, output); + rawData = cv::Mat(signalResultNum, strideNum, CV_16F, output); rawData.convertTo(rawData, CV_32F); } //Note: //ultralytics add transpose operator to the output of yolov8 model.which make yolov8/v5/v7 has same shape //https://github.com/ultralytics/assets/releases/download/v8.2.0/yolov8n.pt - //rowData = rowData.t(); + rawData = rawData.t(); float* data = (float*)rawData.data; From b2604c7df12ab867e7942843f1bcec665fad66ef Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sun, 25 Aug 2024 23:43:35 +0800 Subject: [PATCH 2/2] `ultralytics 8.2.82` YOLOv10 CoreML, Edge TPU and TF.js export support (#15796) Signed-off-by: UltralyticsAssistant Co-authored-by: UltralyticsAssistant Co-authored-by: Ryan Hirasaki <4690732+RyanHir@users.noreply.github.com> --- docs/en/models/yolov10.md | 28 +++++++++++----------- ultralytics/__init__.py | 2 +- ultralytics/nn/modules/head.py | 41 +++++++++++---------------------- ultralytics/utils/benchmarks.py | 5 ++-- 4 files changed, 31 insertions(+), 45 deletions(-) diff --git a/docs/en/models/yolov10.md b/docs/en/models/yolov10.md index 98a164a42..e8e4a2862 100644 --- a/docs/en/models/yolov10.md +++ b/docs/en/models/yolov10.md @@ -202,20 +202,20 @@ The YOLOv10 models series offers a range of models, each optimized for high-perf Due to the new operations introduced with YOLOv10, not all export formats provided by Ultralytics are currently supported. The following table outlines which formats have been successfully converted using Ultralytics for YOLOv10. Feel free to open a pull request if you're able to [provide a contribution change](../help/contributing.md) for adding export support of additional formats for YOLOv10. -| Export Format | Supported | -| ------------------------------------------------- | --------- | -| [TorchScript](../integrations/torchscript.md) | ✅ | -| [ONNX](../integrations/onnx.md) | ✅ | -| [OpenVINO](../integrations/openvino.md) | ✅ | -| [TensorRT](../integrations/tensorrt.md) | ✅ | -| [CoreML](../integrations/coreml.md) | ❌ | -| [TF SavedModel](../integrations/tf-savedmodel.md) | ✅ | -| [TF GraphDef](../integrations/tf-graphdef.md) | ✅ | -| [TF Lite](../integrations/tflite.md) | ✅ | -| [TF Edge TPU](../integrations/edge-tpu.md) | ❌ | -| [TF.js](../integrations/tfjs.md) | ❌ | -| [PaddlePaddle](../integrations/paddlepaddle.md) | ❌ | -| [NCNN](../integrations/ncnn.md) | ❌ | +| Export Format | Export Support | Exported Model Inference | Notes | +| ------------------------------------------------- | -------------- | ------------------------ | ------------------------------------------- | +| [TorchScript](../integrations/torchscript.md) | ✅ | ✅ | Standard PyTorch model format. | +| [ONNX](../integrations/onnx.md) | ✅ | ✅ | Widely supported for deployment. | +| [OpenVINO](../integrations/openvino.md) | ✅ | ✅ | Optimized for Intel hardware. | +| [TensorRT](../integrations/tensorrt.md) | ✅ | ✅ | Optimized for NVIDIA GPUs. | +| [CoreML](../integrations/coreml.md) | ✅ | ✅ | Limited to Apple devices. | +| [TF SavedModel](../integrations/tf-savedmodel.md) | ✅ | ✅ | TensorFlow's standard model format. | +| [TF GraphDef](../integrations/tf-graphdef.md) | ✅ | ✅ | Legacy TensorFlow format. | +| [TF Lite](../integrations/tflite.md) | ✅ | ✅ | Optimized for mobile and embedded. | +| [TF Edge TPU](../integrations/edge-tpu.md) | ✅ | ✅ | Specific to Google's Edge TPU devices. | +| [TF.js](../integrations/tfjs.md) | ✅ | ✅ | JavaScript environment for browser use. | +| [PaddlePaddle](../integrations/paddlepaddle.md) | ❌ | ❌ | Popular in China; less global support. | +| [NCNN](../integrations/ncnn.md) | ✅ | ❌ | Layer `torch.topk` not exists or registered | ## Conclusion diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 5d7906d07..a4e8dd21e 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.2.81" +__version__ = "8.2.82" import os diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 06812e181..ed0b90f80 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn from torch.nn.init import constant_, xavier_uniform_ -from ultralytics.utils import MACOS from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto @@ -133,38 +132,26 @@ class Detect(nn.Module): @staticmethod def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80): """ - Post-processes the predictions obtained from a YOLOv10 model. + Post-processes YOLO model predictions. Args: - preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes). - max_det (int): The maximum number of detections to keep. - nc (int, optional): The number of classes. Defaults to 80. + preds (torch.Tensor): Raw predictions with shape (batch_size, num_anchors, 4 + nc) with last dimension + format [x, y, w, h, class_probs]. + max_det (int): Maximum detections per image. + nc (int, optional): Number of classes. Default: 80. Returns: - (torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6), - including bounding boxes, scores and cls. + (torch.Tensor): Processed predictions with shape (batch_size, min(max_det, num_anchors), 6) and last + dimension format [x, y, w, h, max_class_prob, class_index]. """ - assert 4 + nc == preds.shape[-1] + batch_size, anchors, predictions = preds.shape # i.e. shape(16,8400,84) boxes, scores = preds.split([4, nc], dim=-1) - max_scores = scores.amax(dim=-1) - max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1) - index = index.unsqueeze(-1) - boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) - scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) - - # NOTE: simplify result but slightly lower mAP - # scores, labels = scores.max(dim=-1) - # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) - - scores, index = torch.topk(scores.flatten(1), max_det, axis=-1) - labels = index % nc - index = index // nc - # Set int64 dtype for MPS and CoreML compatibility to avoid 'gather_along_axis' ops error - if MACOS: - index = index.to(torch.int64) - boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) - - return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1) + index = scores.amax(dim=-1).topk(min(max_det, anchors))[1].unsqueeze(-1) + boxes = boxes.gather(dim=1, index=index.repeat(1, 1, 4)) + scores = scores.gather(dim=1, index=index.repeat(1, 1, nc)) + scores, index = scores.flatten(1).topk(max_det) + i = torch.arange(batch_size)[..., None] # batch indices + return torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1) class Segment(Detect): diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index 0c8a7486f..c33a89064 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -97,20 +97,17 @@ def benchmark( assert MACOS or LINUX, "CoreML and TF.js export only supported on macOS and Linux" assert not IS_RASPBERRYPI, "CoreML and TF.js export not supported on Raspberry Pi" assert not IS_JETSON, "CoreML and TF.js export not supported on NVIDIA Jetson" - assert not is_end2end, "End-to-end models not supported by CoreML and TF.js yet" if i in {3, 5}: # CoreML and OpenVINO assert not IS_PYTHON_3_12, "CoreML and OpenVINO not supported on Python 3.12" if i in {6, 7, 8}: # TF SavedModel, TF GraphDef, and TFLite assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" if i in {9, 10}: # TF EdgeTPU and TF.js assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" - assert not is_end2end, "End-to-end models not supported by TF EdgeTPU and TF.js yet" if i in {11}: # Paddle assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet" assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet" if i in {12}: # NCNN assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet" - assert not is_end2end, "End-to-end models not supported by NCNN yet" if "cpu" in device.type: assert cpu, "inference not supported on CPU" if "cuda" in device.type: @@ -130,6 +127,8 @@ def benchmark( assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported" assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML + if i in {12}: + assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet" exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half) # Validate