From 6ffd8841fd8cdd86ab1ce1d102997f941f3c88e8 Mon Sep 17 00:00:00 2001 From: Muhammad Rizwan Munawar Date: Wed, 30 Oct 2024 16:36:03 +0500 Subject: [PATCH 1/5] Update notebooks (#17260) Co-authored-by: UltralyticsAssistant --- examples/heatmaps.ipynb | 2 +- examples/object_counting.ipynb | 2 +- examples/object_tracking.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/heatmaps.ipynb b/examples/heatmaps.ipynb index c674ad480..d0124df89 100644 --- a/examples/heatmaps.ipynb +++ b/examples/heatmaps.ipynb @@ -112,7 +112,7 @@ "heatmap_obj = solutions.Heatmap(\n", " colormap=cv2.COLORMAP_PARULA, # Color of the heatmap\n", " show=True, # Display the image during processing\n", - " model=yolo11n.pt, # Ultralytics YOLO11 model file\n", + " model=\"yolo11n.pt\", # Ultralytics YOLO11 model file\n", ")\n", "\n", "while cap.isOpened():\n", diff --git a/examples/object_counting.ipynb b/examples/object_counting.ipynb index 50168f262..e742cff6a 100644 --- a/examples/object_counting.ipynb +++ b/examples/object_counting.ipynb @@ -123,7 +123,7 @@ "counter = solutions.ObjectCounter(\n", " show=True, # Display the image during processing\n", " region=line_points, # Region of interest points\n", - " model=yolo11n.pt, # Ultralytics YOLO11 model file\n", + " model=\"yolo11n.pt\", # Ultralytics YOLO11 model file\n", " line_width=2, # Thickness of the lines and bounding boxes\n", ")\n", "\n", diff --git a/examples/object_tracking.ipynb b/examples/object_tracking.ipynb index 7691fce9c..cc4d03add 100644 --- a/examples/object_tracking.ipynb +++ b/examples/object_tracking.ipynb @@ -176,7 +176,7 @@ "\n", " # Annotate each mask with its corresponding tracking ID and color\n", " for mask, track_id in zip(masks, track_ids):\n", - " annotator.seg_bbox(mask=mask, mask_color=colors(track_id, True), track_label=str(track_id))\n", + " annotator.seg_bbox(mask=mask, mask_color=colors(int(track_id), True), label=str(track_id))\n", "\n", " # Write the annotated frame to the output video\n", " out.write(im0)\n", From e798dbf52e02c367b657daea85e90bf49c340f3c Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Wed, 30 Oct 2024 19:37:56 +0800 Subject: [PATCH 2/5] Fix missing argument (#17253) --- ultralytics/models/sam/modules/sam.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ultralytics/models/sam/modules/sam.py b/ultralytics/models/sam/modules/sam.py index 562314b2b..7bfd71661 100644 --- a/ultralytics/models/sam/modules/sam.py +++ b/ultralytics/models/sam/modules/sam.py @@ -854,6 +854,7 @@ class SAM2Model(torch.nn.Module): mask_inputs, output_dict, num_frames, + track_in_reverse, prev_sam_mask_logits, ): """Performs a single tracking step, updating object masks and memory features based on current frame inputs.""" From b8c90baffee06b7b162cb29bd94383a693a42744 Mon Sep 17 00:00:00 2001 From: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com> Date: Wed, 30 Oct 2024 19:38:28 +0800 Subject: [PATCH 3/5] Update triton-inference-server.md (#17252) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- docs/en/guides/triton-inference-server.md | 43 ++++++++++++++--------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/docs/en/guides/triton-inference-server.md b/docs/en/guides/triton-inference-server.md index 09f7516b1..0151cc078 100644 --- a/docs/en/guides/triton-inference-server.md +++ b/docs/en/guides/triton-inference-server.md @@ -83,25 +83,34 @@ The Triton Model Repository is a storage location where Triton can access and 
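load models. Each model lives in its own subdirectory containing numbered version folders and an optional `config.pbtxt`. As a minimal sketch of preparing that layout (the `yolo` model name and `tmp/triton_repo` path match the Docker commands later in this guide; the ONNX export step is an assumption):

```python
from pathlib import Path

from ultralytics import YOLO

# Export YOLO11n to ONNX, then arrange it in Triton's <repo>/<model>/<version>/model.onnx layout
onnx_file = YOLO("yolo11n.pt").export(format="onnx", dynamic=True)

triton_repo_path = Path("tmp") / "triton_repo"  # Triton model repository root
triton_model_path = triton_repo_path / "yolo"  # directory for this model

(triton_model_path / "1").mkdir(parents=True, exist_ok=True)
Path(onnx_file).rename(triton_model_path / "1" / "model.onnx")
(triton_model_path / "config.pbtxt").touch()  # populated below
```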
lo # (Optional) Enable TensorRT for GPU inference # First run will be slow due to TensorRT engine conversion - import json - - data = { - "optimization": { - "execution_accelerators": { - "gpu_execution_accelerator": [ - { - "name": "tensorrt", - "parameters": {"key": "precision_mode", "value": "FP16"}, - "parameters": {"key": "max_workspace_size_bytes", "value": "3221225472"}, - "parameters": {"key": "trt_engine_cache_enable", "value": "1"}, - } - ] - } + data = """ + optimization { + execution_accelerators { + gpu_execution_accelerator { + name: "tensorrt" + parameters { + key: "precision_mode" + value: "FP16" + } + parameters { + key: "max_workspace_size_bytes" + value: "3221225472" + } + parameters { + key: "trt_engine_cache_enable" + value: "1" + } + parameters { + key: "trt_engine_cache_path" + value: "/models/yolo/1" + } } + } } + """ with open(triton_model_path / "config.pbtxt", "w") as f: - json.dump(data, f, indent=4) + f.write(data) ``` ## Running Triton Inference Server @@ -124,7 +133,7 @@ subprocess.call(f"docker pull {tag}", shell=True) # Run the Triton server and capture the container ID container_id = ( subprocess.check_output( - f"docker run -d --rm -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models", + f"docker run -d --rm --gpus 0 -v {triton_repo_path}:/models -p 8000:8000 {tag} tritonserver --model-repository=/models", shell=True, ) .decode("utf-8") @@ -215,7 +224,7 @@ Setting up [Ultralytics YOLO11](https://docs.ultralytics.com/models/yolov8/) wit container_id = ( subprocess.check_output( - f"docker run -d --rm -v {triton_repo_path}/models -p 8000:8000 {tag} tritonserver --model-repository=/models", + f"docker run -d --rm --gpus 0 -v {triton_repo_path}/models -p 8000:8000 {tag} tritonserver --model-repository=/models", shell=True, ) .decode("utf-8") From 11b419434487a894656fe46d819d2ae868f25cc1 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 30 Oct 2024 13:42:39 +0100 Subject: [PATCH 4/5] Disable Ray tests (#17266) Co-authored-by: UltralyticsAssistant --- .github/workflows/ci.yaml | 4 ++-- ultralytics/utils/tuner.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6963156ce..381e92c4c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -184,7 +184,7 @@ jobs: torch="torch==1.8.0 torchvision==0.9.0" fi if [[ "${{ github.event_name }}" =~ ^(schedule|workflow_dispatch)$ ]]; then - slow="pycocotools mlflow ray[tune]" + slow="pycocotools mlflow" fi pip install -e ".[export]" $torch $slow pytest-cov --extra-index-url https://download.pytorch.org/whl/cpu - name: Check environment @@ -247,7 +247,7 @@ jobs: - name: Install requirements run: | python -m pip install --upgrade pip wheel - pip install -e ".[export]" pytest mlflow pycocotools "ray[tune]" + pip install -e ".[export]" pytest mlflow pycocotools - name: Check environment run: | yolo checks diff --git a/ultralytics/utils/tuner.py b/ultralytics/utils/tuner.py index e611fa9af..165c788a7 100644 --- a/ultralytics/utils/tuner.py +++ b/ultralytics/utils/tuner.py @@ -1,12 +1,16 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license - from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_save_dir from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks def run_ray_tune( - model, space: dict = None, grace_period: int = 10, gpu_per_trial: int = None, max_samples: int = 10, **train_args + model, + space: dict = None, + grace_period: int = 10, + gpu_per_trial: 
int = None, + max_samples: int = 10, + **train_args, ): """ Runs hyperparameter tuning using Ray Tune. @@ -38,7 +42,7 @@ def run_ray_tune( train_args = {} try: - checks.check_requirements(("ray[tune]", "numpy<2.0.0")) + checks.check_requirements("ray[tune]") import ray from ray import tune From 9c72d94ba4a83e8911595341b0c3a1d30bbbe8a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=8F=AC=E5=BE=B7?= <8401806+wangzhaode@users.noreply.github.com> Date: Wed, 30 Oct 2024 20:59:48 +0800 Subject: [PATCH 5/5] `ultralytics 8.3.25` Alibaba MNN export and predict support (#16802) Co-authored-by: UltralyticsAssistant Co-authored-by: Francesco Mattioli Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Laughing-q <1185102784@qq.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> Co-authored-by: Glenn Jocher --- .gitignore | 1 + docs/en/guides/model-deployment-options.md | 35 ++- docs/en/integrations/index.md | 6 +- docs/en/integrations/mnn.md | 342 +++++++++++++++++++++ docs/en/macros/export-table.md | 1 + docs/mkdocs_github_authors.yaml | 3 + mkdocs.yml | 3 +- tests/test_exports.py | 7 + ultralytics/__init__.py | 2 +- ultralytics/engine/exporter.py | 37 ++- ultralytics/engine/predictor.py | 1 + ultralytics/engine/validator.py | 1 + ultralytics/nn/autobackend.py | 58 +++- ultralytics/utils/benchmarks.py | 7 +- 14 files changed, 465 insertions(+), 39 deletions(-) create mode 100644 docs/en/integrations/mnn.md diff --git a/.gitignore b/.gitignore index 5cc365b4d..4e0f0845b 100644 --- a/.gitignore +++ b/.gitignore @@ -157,6 +157,7 @@ weights/ *.torchscript *.tflite *.h5 +*.mnn *_saved_model/ *_web_model/ *_openvino_model/ diff --git a/docs/en/guides/model-deployment-options.md b/docs/en/guides/model-deployment-options.md index a9efee17c..1b97e31e4 100644 --- a/docs/en/guides/model-deployment-options.md +++ b/docs/en/guides/model-deployment-options.md @@ -258,25 +258,30 @@ NCNN is a high-performance neural network inference framework optimized for the - **Hardware Acceleration**: Tailored for ARM CPUs and GPUs, with specific optimizations for these architectures. +#### MNN + +MNN is a highly efficient and lightweight deep learning framework. It supports inference and training of deep learning models and has industry-leading performance for inference and training on-device. In addition, MNN is also used on embedded devices, such as IoT. + ## Comparative Analysis of YOLO11 Deployment Options The following table provides a snapshot of the various deployment options available for YOLO11 models, helping you to assess which may best fit your project needs based on several critical criteria. For an in-depth look at each deployment option's format, please see the [Ultralytics documentation page on export formats](../modes/export.md#export-formats). 
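Every option in this table is produced through the same Ultralytics export call; only the `format` string changes (see the Argument column on the [export formats](../modes/export.md#export-formats) page). A quick sketch:

```python
from ultralytics import YOLO

# Load a pretrained model, then export it to the chosen deployment format
model = YOLO("yolo11n.pt")
model.export(format="onnx")  # e.g. "openvino", "engine", "coreml", "tflite", "mnn", "ncnn"
```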
-| Deployment Option | Performance Benchmarks | Compatibility and Integration | Community Support and Ecosystem | Case Studies | Maintenance and Updates | Security Considerations | Hardware Acceleration | -| ----------------- | ----------------------------------------------- | ---------------------------------------------- | --------------------------------------------- | ------------------------------------------ | ------------------------------------------- | ------------------------------------------------- | ---------------------------------- | -| PyTorch | Good flexibility; may trade off raw performance | Excellent with Python libraries | Extensive resources and community | Research and prototypes | Regular, active development | Dependent on deployment environment | CUDA support for GPU acceleration | -| TorchScript | Better for production than PyTorch | Smooth transition from PyTorch to C++ | Specialized but narrower than PyTorch | Industry where Python is a bottleneck | Consistent updates with PyTorch | Improved security without full Python | Inherits CUDA support from PyTorch | -| ONNX | Variable depending on runtime | High across different frameworks | Broad ecosystem, supported by many orgs | Flexibility across ML frameworks | Regular updates for new operations | Ensure secure conversion and deployment practices | Various hardware optimizations | -| OpenVINO | Optimized for Intel hardware | Best within Intel ecosystem | Solid in computer vision domain | IoT and edge with Intel hardware | Regular updates for Intel hardware | Robust features for sensitive applications | Tailored for Intel hardware | -| TensorRT | Top-tier on NVIDIA GPUs | Best for NVIDIA hardware | Strong network through NVIDIA | Real-time video and image inference | Frequent updates for new GPUs | Emphasis on security | Designed for NVIDIA GPUs | -| CoreML | Optimized for on-device Apple hardware | Exclusive to Apple ecosystem | Strong Apple and developer support | On-device ML on Apple products | Regular Apple updates | Focus on privacy and security | Apple neural engine and GPU | -| TF SavedModel | Scalable in server environments | Wide compatibility in TensorFlow ecosystem | Large support due to TensorFlow popularity | Serving models at scale | Regular updates by Google and community | Robust features for enterprise | Various hardware accelerations | -| TF GraphDef | Stable for static computation graphs | Integrates well with TensorFlow infrastructure | Resources for optimizing static graphs | Scenarios requiring static graphs | Updates alongside TensorFlow core | Established TensorFlow security practices | TensorFlow acceleration options | -| TF Lite | Speed and efficiency on mobile/embedded | Wide range of device support | Robust community, Google backed | Mobile applications with minimal footprint | Latest features for mobile | Secure environment on end-user devices | GPU and DSP among others | -| TF Edge TPU | Optimized for Google's Edge TPU hardware | Exclusive to Edge TPU devices | Growing with Google and third-party resources | IoT devices requiring real-time processing | Improvements for new Edge TPU hardware | Google's robust IoT security | Custom-designed for Google Coral | -| TF.js | Reasonable in-browser performance | High with web technologies | Web and Node.js developers support | Interactive web applications | TensorFlow team and community contributions | Web platform security model | Enhanced with WebGL and other APIs | -| PaddlePaddle | Competitive, easy to use and scalable | Baidu ecosystem, 
wide application support | Rapidly growing, especially in China | Chinese market and language processing | Focus on Chinese AI applications | Emphasizes data privacy and security | Including Baidu's Kunlun chips | -| NCNN | Optimized for mobile ARM-based devices | Mobile and embedded ARM systems | Niche but active mobile/embedded ML community | Android and ARM systems efficiency | High performance maintenance on ARM | On-device security advantages | ARM CPUs and GPUs optimizations | +| Deployment Option | Performance Benchmarks | Compatibility and Integration | Community Support and Ecosystem | Case Studies | Maintenance and Updates | Security Considerations | Hardware Acceleration | +| ----------------- | ----------------------------------------------- | ---------------------------------------------- | --------------------------------------------- | ------------------------------------------ | ---------------------------------------------- | ------------------------------------------------- | ---------------------------------- | +| PyTorch | Good flexibility; may trade off raw performance | Excellent with Python libraries | Extensive resources and community | Research and prototypes | Regular, active development | Dependent on deployment environment | CUDA support for GPU acceleration | +| TorchScript | Better for production than PyTorch | Smooth transition from PyTorch to C++ | Specialized but narrower than PyTorch | Industry where Python is a bottleneck | Consistent updates with PyTorch | Improved security without full Python | Inherits CUDA support from PyTorch | +| ONNX | Variable depending on runtime | High across different frameworks | Broad ecosystem, supported by many orgs | Flexibility across ML frameworks | Regular updates for new operations | Ensure secure conversion and deployment practices | Various hardware optimizations | +| OpenVINO | Optimized for Intel hardware | Best within Intel ecosystem | Solid in computer vision domain | IoT and edge with Intel hardware | Regular updates for Intel hardware | Robust features for sensitive applications | Tailored for Intel hardware | +| TensorRT | Top-tier on NVIDIA GPUs | Best for NVIDIA hardware | Strong network through NVIDIA | Real-time video and image inference | Frequent updates for new GPUs | Emphasis on security | Designed for NVIDIA GPUs | +| CoreML | Optimized for on-device Apple hardware | Exclusive to Apple ecosystem | Strong Apple and developer support | On-device ML on Apple products | Regular Apple updates | Focus on privacy and security | Apple neural engine and GPU | +| TF SavedModel | Scalable in server environments | Wide compatibility in TensorFlow ecosystem | Large support due to TensorFlow popularity | Serving models at scale | Regular updates by Google and community | Robust features for enterprise | Various hardware accelerations | +| TF GraphDef | Stable for static computation graphs | Integrates well with TensorFlow infrastructure | Resources for optimizing static graphs | Scenarios requiring static graphs | Updates alongside TensorFlow core | Established TensorFlow security practices | TensorFlow acceleration options | +| TF Lite | Speed and efficiency on mobile/embedded | Wide range of device support | Robust community, Google backed | Mobile applications with minimal footprint | Latest features for mobile | Secure environment on end-user devices | GPU and DSP among others | +| TF Edge TPU | Optimized for Google's Edge TPU hardware | Exclusive to Edge TPU devices | Growing with Google and third-party 
resources | IoT devices requiring real-time processing | Improvements for new Edge TPU hardware | Google's robust IoT security | Custom-designed for Google Coral | +| TF.js | Reasonable in-browser performance | High with web technologies | Web and Node.js developers support | Interactive web applications | TensorFlow team and community contributions | Web platform security model | Enhanced with WebGL and other APIs | +| PaddlePaddle | Competitive, easy to use and scalable | Baidu ecosystem, wide application support | Rapidly growing, especially in China | Chinese market and language processing | Focus on Chinese AI applications | Emphasizes data privacy and security | Including Baidu's Kunlun chips | +| MNN | High-performance for mobile devices. | Mobile and embedded ARM systems and X86-64 CPU | Mobile/embedded ML community | Moblile systems efficiency | High performance maintenance on Mobile Devices | On-device security advantages | ARM CPUs and GPUs optimizations | +| NCNN | Optimized for mobile ARM-based devices | Mobile and embedded ARM systems | Niche but active mobile/embedded ML community | Android and ARM systems efficiency | High performance maintenance on ARM | On-device security advantages | ARM CPUs and GPUs optimizations | This comparative analysis gives you a high-level overview. For deployment, it's essential to consider the specific requirements and constraints of your project, and consult the detailed documentation and resources available for each option. diff --git a/docs/en/integrations/index.md b/docs/en/integrations/index.md index bdb8b9c90..f2859e838 100644 --- a/docs/en/integrations/index.md +++ b/docs/en/integrations/index.md @@ -57,6 +57,8 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of - [Weights & Biases (W&B)](weights-biases.md): Monitor experiments, visualize metrics, and foster reproducibility and collaboration on Ultralytics projects. +- [VS Code](vscode.md): An extension for VS Code that provides code snippets for accelerating development workflows with Ultralytics and also for anyone looking for examples to help learn or get started with Ultralytics. + ## Deployment Integrations - [CoreML](coreml.md): CoreML, developed by [Apple](https://www.apple.com/), is a framework designed for efficiently integrating machine learning models into applications across iOS, macOS, watchOS, and tvOS, using Apple's hardware for effective and secure [model deployment](https://www.ultralytics.com/glossary/model-deployment). @@ -65,6 +67,8 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of - [NCNN](ncnn.md): Developed by [Tencent](http://www.tencent.com/), NCNN is an efficient [neural network](https://www.ultralytics.com/glossary/neural-network-nn) inference framework tailored for mobile devices. It enables direct deployment of AI models into apps, optimizing performance across various mobile platforms. +- [MNN](mnn.md): Developed by [Alibaba](https://www.alibabagroup.com/), MNN is a highly efficient and lightweight deep learning framework. It supports inference and training of deep learning models and has industry-leading performance for inference and training on-device. + - [Neural Magic](neural-magic.md): Leverage Quantization Aware Training (QAT) and pruning techniques to optimize Ultralytics models for superior performance and leaner size. 
- [ONNX](onnx.md): An open-source format created by [Microsoft](https://www.microsoft.com/) for facilitating the transfer of AI models between various frameworks, enhancing the versatility and deployment flexibility of Ultralytics models. @@ -87,8 +91,6 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of - [TorchScript](torchscript.md): Developed as part of the [PyTorch](https://pytorch.org/) framework, TorchScript enables efficient execution and deployment of machine learning models in various production environments without the need for Python dependencies. -- [VS Code](vscode.md): An extension for VS Code that provides code snippets for accelerating development workflows with Ultralytics and also for anyone looking for examples to help learn or get started with Ultralytics. - ### Export Formats We also support a variety of model export formats for deployment in different environments. Here are the available formats: diff --git a/docs/en/integrations/mnn.md b/docs/en/integrations/mnn.md new file mode 100644 index 000000000..591937361 --- /dev/null +++ b/docs/en/integrations/mnn.md @@ -0,0 +1,342 @@ +--- +comments: true +description: Optimize YOLO11 models for mobile and embedded devices by exporting to MNN format. +keywords: Ultralytics, YOLO11, MNN, model export, machine learning, deployment, mobile, embedded systems, deep learning, AI models +--- + +# MNN Export for YOLO11 Models and Deploy + +## MNN + +

+*MNN architecture*
+ +[MNN](https://github.com/alibaba/MNN) is a highly efficient and lightweight deep learning framework. It supports inference and training of deep learning models and has industry-leading performance for inference and training on-device. At present, MNN has been integrated into more than 30 apps of Alibaba Inc, such as Taobao, Tmall, Youku, DingTalk, Xianyu, etc., covering more than 70 usage scenarios such as live broadcast, short video capture, search recommendation, product searching by image, interactive marketing, equity distribution, security risk control. In addition, MNN is also used on embedded devices, such as IoT. + +## Export to MNN: Converting Your YOLO11 Model + +You can expand model compatibility and deployment flexibility by converting YOLO11 models to MNN format. + +### Installation + +To install the required packages, run: + +!!! tip "Installation" + + === "CLI" + + ```bash + # Install the required package for YOLO11 and MNN + pip install ultralytics + pip install MNN + ``` + +### Usage + +Before diving into the usage instructions, it's important to note that while all [Ultralytics YOLO11 models](../models/index.md) are available for exporting, you can ensure that the model you select supports export functionality [here](../modes/export.md). + +!!! example "Usage" + + === "Python" + + ```python + from ultralytics import YOLO + + # Load the YOLO11 model + model = YOLO("yolo11n.pt") + + # Export the model to MNN format + model.export(format="mnn") # creates 'yolo11n.mnn' + + # Load the exported MNN model + mnn_model = YOLO("yolo11n.mnn") + + # Run inference + results = mnn_model("https://ultralytics.com/images/bus.jpg") + ``` + + === "CLI" + + ```bash + # Export a YOLO11n PyTorch model to MNN format + yolo export model=yolo11n.pt format=mnn # creates 'yolo11n.mnn' + + # Run inference with the exported model + yolo predict model='yolo11n.mnn' source='https://ultralytics.com/images/bus.jpg' + ``` + +For more details about supported export options, visit the [Ultralytics documentation page on deployment options](../guides/model-deployment-options.md). + +### MNN-Only Inference + +A function that relies solely on MNN for YOLO11 inference and preprocessing is implemented, providing both Python and C++ versions for easy deployment in any scenario. + +!!! 
example "MNN" + + === "Python" + + ```python + import argparse + + import MNN + import MNN.cv as cv2 + import MNN.numpy as np + + + def inference(model, img, precision, backend, thread): + config = {} + config["precision"] = precision + config["backend"] = backend + config["numThread"] = thread + rt = MNN.nn.create_runtime_manager((config,)) + # net = MNN.nn.load_module_from_file(model, ['images'], ['output0'], runtime_manager=rt) + net = MNN.nn.load_module_from_file(model, [], [], runtime_manager=rt) + original_image = cv2.imread(img) + ih, iw, _ = original_image.shape + length = max((ih, iw)) + scale = length / 640 + image = np.pad(original_image, [[0, length - ih], [0, length - iw], [0, 0]], "constant") + image = cv2.resize( + image, (640, 640), 0.0, 0.0, cv2.INTER_LINEAR, -1, [0.0, 0.0, 0.0], [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0] + ) + input_var = np.expand_dims(image, 0) + input_var = MNN.expr.convert(input_var, MNN.expr.NC4HW4) + output_var = net.forward(input_var) + output_var = MNN.expr.convert(output_var, MNN.expr.NCHW) + output_var = output_var.squeeze() + # output_var shape: [84, 8400]; 84 means: [cx, cy, w, h, prob * 80] + cx = output_var[0] + cy = output_var[1] + w = output_var[2] + h = output_var[3] + probs = output_var[4:] + # [cx, cy, w, h] -> [y0, x0, y1, x1] + x0 = cx - w * 0.5 + y0 = cy - h * 0.5 + x1 = cx + w * 0.5 + y1 = cy + h * 0.5 + boxes = np.stack([x0, y0, x1, y1], axis=1) + # get max prob and idx + scores = np.max(probs, 0) + class_ids = np.argmax(probs, 0) + result_ids = MNN.expr.nms(boxes, scores, 100, 0.45, 0.25) + print(result_ids.shape) + # nms result box, score, ids + result_boxes = boxes[result_ids] + result_scores = scores[result_ids] + result_class_ids = class_ids[result_ids] + for i in range(len(result_boxes)): + x0, y0, x1, y1 = result_boxes[i].read_as_tuple() + y0 = int(y0 * scale) + y1 = int(y1 * scale) + x0 = int(x0 * scale) + x1 = int(x1 * scale) + print(result_class_ids[i]) + cv2.rectangle(original_image, (x0, y0), (x1, y1), (0, 0, 255), 2) + cv2.imwrite("res.jpg", original_image) + + + if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True, help="the yolo11 model path") + parser.add_argument("--img", type=str, required=True, help="the input image path") + parser.add_argument("--precision", type=str, default="normal", help="inference precision: normal, low, high, lowBF") + parser.add_argument( + "--backend", + type=str, + default="CPU", + help="inference backend: CPU, OPENCL, OPENGL, NN, VULKAN, METAL, TRT, CUDA, HIAI", + ) + parser.add_argument("--thread", type=int, default=4, help="inference using thread: int") + args = parser.parse_args() + inference(args.model, args.img, args.precision, args.backend, args.thread) + ``` + + === "CPP" + + ```cpp + #include + #include + #include + #include + #include + #include + + #include + + using namespace MNN; + using namespace MNN::Express; + using namespace MNN::CV; + + int main(int argc, const char* argv[]) { + if (argc < 3) { + MNN_PRINT("Usage: ./yolo11_demo.out model.mnn input.jpg [forwardType] [precision] [thread]\n"); + return 0; + } + int thread = 4; + int precision = 0; + int forwardType = MNN_FORWARD_CPU; + if (argc >= 4) { + forwardType = atoi(argv[3]); + } + if (argc >= 5) { + precision = atoi(argv[4]); + } + if (argc >= 6) { + thread = atoi(argv[5]); + } + MNN::ScheduleConfig sConfig; + sConfig.type = static_cast(forwardType); + sConfig.numThread = thread; + BackendConfig bConfig; + bConfig.precision = static_cast(precision); + 
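        // forwardType and precision arrive as plain ints and are cast to MNN's enums above;
        // in MNN's numbering, 0 selects the CPU backend and normal precision respectively.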
sConfig.backendConfig = &bConfig; + std::shared_ptr rtmgr = std::shared_ptr(Executor::RuntimeManager::createRuntimeManager(sConfig)); + if(rtmgr == nullptr) { + MNN_ERROR("Empty RuntimeManger\n"); + return 0; + } + rtmgr->setCache(".cachefile"); + + std::shared_ptr net(Module::load(std::vector{}, std::vector{}, argv[1], rtmgr)); + auto original_image = imread(argv[2]); + auto dims = original_image->getInfo()->dim; + int ih = dims[0]; + int iw = dims[1]; + int len = ih > iw ? ih : iw; + float scale = len / 640.0; + std::vector padvals { 0, len - ih, 0, len - iw, 0, 0 }; + auto pads = _Const(static_cast(padvals.data()), {3, 2}, NCHW, halide_type_of()); + auto image = _Pad(original_image, pads, CONSTANT); + image = resize(image, Size(640, 640), 0, 0, INTER_LINEAR, -1, {0., 0., 0.}, {1./255., 1./255., 1./255.}); + auto input = _Unsqueeze(image, {0}); + input = _Convert(input, NC4HW4); + auto outputs = net->onForward({input}); + auto output = _Convert(outputs[0], NCHW); + output = _Squeeze(output); + // output shape: [84, 8400]; 84 means: [cx, cy, w, h, prob * 80] + auto cx = _Gather(output, _Scalar(0)); + auto cy = _Gather(output, _Scalar(1)); + auto w = _Gather(output, _Scalar(2)); + auto h = _Gather(output, _Scalar(3)); + std::vector startvals { 4, 0 }; + auto start = _Const(static_cast(startvals.data()), {2}, NCHW, halide_type_of()); + std::vector sizevals { -1, -1 }; + auto size = _Const(static_cast(sizevals.data()), {2}, NCHW, halide_type_of()); + auto probs = _Slice(output, start, size); + // [cx, cy, w, h] -> [y0, x0, y1, x1] + auto x0 = cx - w * _Const(0.5); + auto y0 = cy - h * _Const(0.5); + auto x1 = cx + w * _Const(0.5); + auto y1 = cy + h * _Const(0.5); + auto boxes = _Stack({x0, y0, x1, y1}, 1); + auto scores = _ReduceMax(probs, {0}); + auto ids = _ArgMax(probs, 0); + auto result_ids = _Nms(boxes, scores, 100, 0.45, 0.25); + auto result_ptr = result_ids->readMap(); + auto box_ptr = boxes->readMap(); + auto ids_ptr = ids->readMap(); + auto score_ptr = scores->readMap(); + for (int i = 0; i < 100; i++) { + auto idx = result_ptr[i]; + if (idx < 0) break; + auto x0 = box_ptr[idx * 4 + 0] * scale; + auto y0 = box_ptr[idx * 4 + 1] * scale; + auto x1 = box_ptr[idx * 4 + 2] * scale; + auto y1 = box_ptr[idx * 4 + 3] * scale; + auto class_idx = ids_ptr[idx]; + auto score = score_ptr[idx]; + rectangle(original_image, {x0, y0}, {x1, y1}, {0, 0, 255}, 2); + } + if (imwrite("res.jpg", original_image)) { + MNN_PRINT("result image write to `res.jpg`.\n"); + } + rtmgr->updateCache(); + return 0; + } + ``` + +## Summary + +In this guide, we introduce how to export the Ultralytics YOLO11 model to MNN and use MNN for inference. + +For more usage, please refer to the [MNN documentation](https://mnn-docs.readthedocs.io/en/latest). + +## FAQ + +### How do I export Ultralytics YOLO11 models to MNN format? + +To export your Ultralytics YOLO11 model to MNN format, follow these steps: + +!!! 
example "Export" + + === "Python" + + ```python + from ultralytics import YOLO + + # Load the YOLO11 model + model = YOLO("yolo11n.pt") + + # Export to MNN format + model.export(format="mnn") # creates 'yolo11n.mnn' with fp32 weight + model.export(format="mnn", half=True) # creates 'yolo11n.mnn' with fp16 weight + model.export(format="mnn", int8=True) # creates 'yolo11n.mnn' with int8 weight + ``` + + === "CLI" + + ```bash + yolo export model=yolo11n.pt format=mnn # creates 'yolo11n.mnn' with fp32 weight + yolo export model=yolo11n.pt format=mnn half=True # creates 'yolo11n.mnn' with fp16 weight + yolo export model=yolo11n.pt format=mnn int8=True # creates 'yolo11n.mnn' with int8 weight + ``` + +For detailed export options, check the [Export](../modes/export.md) page in the documentation. + +### How do I predict with an exported YOLO11 MNN model? + +To predict with an exported YOLO11 MNN model, use the `predict` function from the YOLO class. + +!!! example "Predict" + + === "Python" + + ```python + from ultralytics import YOLO + + # Load the YOLO11 MNN model + model = YOLO("yolo11n.mnn") + + # Export to MNN format + results = mnn_model("https://ultralytics.com/images/bus.jpg") # predict with `fp32` + results = mnn_model("https://ultralytics.com/images/bus.jpg", half=True) # predict with `fp16` if device support + + for result in results: + result.show() # display to screen + result.save(filename="result.jpg") # save to disk + ``` + + === "CLI" + + ```bash + yolo predict model='yolo11n.mnn' source='https://ultralytics.com/images/bus.jpg' # predict with `fp32` + yolo predict model='yolo11n.mnn' source='https://ultralytics.com/images/bus.jpg' --half=True # predict with `fp16` if device support + ``` + +### What platforms are supported for MNN? + +MNN is versatile and supports various platforms: + +- **Mobile**: Android, iOS, Harmony. +- **Embedded Systems and IoT Devices**: Devices like Raspberry Pi and NVIDIA Jetson. +- **Desktop and Servers**: Linux, Windows, and macOS. + +### How can I deploy Ultralytics YOLO11 MNN models on Mobile Devices? + +To deploy your YOLO11 models on Mobile devices: + +1. **Build for Android**: Follow the [MNN Android](https://github.com/alibaba/MNN/tree/master/project/android). +2. **Build for iOS**: Follow the [MNN iOS](https://github.com/alibaba/MNN/tree/master/project/ios). +3. **Build for Harmony**: Follow the [MNN Harmony](https://github.com/alibaba/MNN/tree/master/project/harmony). 
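### How can I benchmark an exported MNN model?

MNN is included in the Ultralytics benchmark utility alongside the other export formats, so one quick sanity check is to benchmark across formats and compare the MNN row — a sketch, with image size and device chosen for illustration:

```python
from ultralytics.utils.benchmarks import benchmark

# Benchmark YOLO11n across all export formats, including MNN, on CPU
benchmark(model="yolo11n.pt", imgsz=640, device="cpu")
```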
diff --git a/docs/en/macros/export-table.md b/docs/en/macros/export-table.md index 7cda31963..b7134f42b 100644 --- a/docs/en/macros/export-table.md +++ b/docs/en/macros/export-table.md @@ -12,4 +12,5 @@ | [TF Edge TPU](../integrations/edge-tpu.md) | `edgetpu` | `{{ model_name or "yolo11n" }}_edgetpu.tflite` | ✅ | `imgsz` | | [TF.js](../integrations/tfjs.md) | `tfjs` | `{{ model_name or "yolo11n" }}_web_model/` | ✅ | `imgsz`, `half`, `int8`, `batch` | | [PaddlePaddle](../integrations/paddlepaddle.md) | `paddle` | `{{ model_name or "yolo11n" }}_paddle_model/` | ✅ | `imgsz`, `batch` | +| [MNN](../integrations/mnn.md) | `mnn` | `{{ model_name or "yolo11n" }}.mnn` | ✅ | `imgsz`, `batch`, `int8`, `half` | | [NCNN](../integrations/ncnn.md) | `ncnn` | `{{ model_name or "yolo11n" }}_ncnn_model/` | ✅ | `imgsz`, `half`, `batch` | diff --git a/docs/mkdocs_github_authors.yaml b/docs/mkdocs_github_authors.yaml index 2e2092138..55ac6ec95 100644 --- a/docs/mkdocs_github_authors.yaml +++ b/docs/mkdocs_github_authors.yaml @@ -154,3 +154,6 @@ web@ultralytics.com: xinwang614@gmail.com: avatar: https://avatars.githubusercontent.com/u/17264618?v=4 username: GreatV +zhaode.wzd@alibaba-inc.com: + avatar: https://avatars.githubusercontent.com/u/8401806?v=4 + username: ZhaodeWang diff --git a/mkdocs.yml b/mkdocs.yml index a7157ec94..3ee15f83b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -398,11 +398,12 @@ nav: - JupyterLab: integrations/jupyterlab.md - Kaggle: integrations/kaggle.md - MLflow: integrations/mlflow.md - - NCNN: integrations/ncnn.md - Neural Magic: integrations/neural-magic.md - ONNX: integrations/onnx.md - OpenVINO: integrations/openvino.md - PaddlePaddle: integrations/paddlepaddle.md + - MNN: integrations/mnn.md + - NCNN: integrations/ncnn.md - Paperspace Gradient: integrations/paperspace.md - Ray Tune: integrations/ray-tune.md - Roboflow: integrations/roboflow.md diff --git a/tests/test_exports.py b/tests/test_exports.py index e6e2ec159..12443fa30 100644 --- a/tests/test_exports.py +++ b/tests/test_exports.py @@ -197,3 +197,10 @@ def test_export_ncnn(): """Test YOLO exports to NCNN format.""" file = YOLO(MODEL).export(format="ncnn", imgsz=32) YOLO(file)(SOURCE, imgsz=32) # exported model inference + + +@pytest.mark.slow +def test_export_mnn(): + """Test YOLO exports to MNN format.""" + file = YOLO(MODEL).export(format="mnn", imgsz=32) + YOLO(file)(SOURCE, imgsz=32) # exported model inference diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py index 72a939647..c847dd4d1 100644 --- a/ultralytics/__init__.py +++ b/ultralytics/__init__.py @@ -1,6 +1,6 @@ # Ultralytics YOLO 🚀, AGPL-3.0 license -__version__ = "8.3.24" +__version__ = "8.3.25" import os diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 5104de1cd..ea8d03b46 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -16,6 +16,7 @@ TensorFlow Lite | `tflite` | yolo11n.tflite TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite TensorFlow.js | `tfjs` | yolo11n_web_model/ PaddlePaddle | `paddle` | yolo11n_paddle_model/ +MNN | `mnn` | yolo11n.mnn NCNN | `ncnn` | yolo11n_ncnn_model/ Requirements: @@ -41,6 +42,7 @@ Inference: yolo11n.tflite # TensorFlow Lite yolo11n_edgetpu.tflite # TensorFlow Edge TPU yolo11n_paddle_model # PaddlePaddle + yolo11n.mnn # MNN yolo11n_ncnn_model # NCNN TensorFlow.js: @@ -109,6 +111,7 @@ def export_formats(): ["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False], ["TensorFlow.js", "tfjs", "_web_model", True, False], ["PaddlePaddle", 
"paddle", "_paddle_model", True, True], + ["MNN", "mnn", ".mnn", True, True], ["NCNN", "ncnn", "_ncnn_model", True, True], ] return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x))) @@ -190,7 +193,9 @@ class Exporter: flags = [x == fmt for x in fmts] if sum(flags) != 1: raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}") - jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans + jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn = ( + flags # export booleans + ) is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs)) # Device @@ -333,8 +338,10 @@ class Exporter: f[9], _ = self.export_tfjs() if paddle: # PaddlePaddle f[10], _ = self.export_paddle() + if mnn: # MNN + f[11], _ = self.export_mnn() if ncnn: # NCNN - f[11], _ = self.export_ncnn() + f[12], _ = self.export_ncnn() # Finish f = [str(x) for x in f if x] # filter out '' and None @@ -541,6 +548,32 @@ class Exporter: yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml return f, None + @try_export + def export_mnn(self, prefix=colorstr("MNN:")): + """YOLOv8 MNN export using MNN https://github.com/alibaba/MNN.""" + f_onnx, _ = self.export_onnx() # get onnx model first + + check_requirements("MNN>=2.9.6") + import MNN # noqa + from MNN.tools import mnnconvert + + # Setup and checks + LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...") + assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}" + f = str(self.file.with_suffix(".mnn")) # MNN model file + args = ["", "-f", "ONNX", "--modelFile", f_onnx, "--MNNModel", f, "--bizCode", json.dumps(self.metadata)] + if self.args.int8: + args.append("--weightQuantBits") + args.append("8") + if self.args.half: + args.append("--fp16") + mnnconvert.convert(args) + # remove scratch file for model convert optimize + convert_scratch = Path(self.file.parent / ".__convert_external_data.bin") + if convert_scratch.exists(): + convert_scratch.unlink() + return f, None + @try_export def export_ncnn(self, prefix=colorstr("NCNN:")): """YOLO NCNN export using PNNX https://github.com/pnnx/pnnx.""" diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py index 16f12a88e..fbe593e06 100644 --- a/ultralytics/engine/predictor.py +++ b/ultralytics/engine/predictor.py @@ -26,6 +26,7 @@ Usage - formats: yolov8n.tflite # TensorFlow Lite yolov8n_edgetpu.tflite # TensorFlow Edge TPU yolov8n_paddle_model # PaddlePaddle + yolov8n.mnn # MNN yolov8n_ncnn_model # NCNN """ diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index daa058a9d..1f6f6912c 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -17,6 +17,7 @@ Usage - formats: yolov8n.tflite # TensorFlow Lite yolov8n_edgetpu.tflite # TensorFlow Edge TPU yolov8n_paddle_model # PaddlePaddle + yolov8n.mnn # MNN yolov8n_ncnn_model # NCNN """ diff --git a/ultralytics/nn/autobackend.py b/ultralytics/nn/autobackend.py index b9312fefd..245e42c4e 100644 --- a/ultralytics/nn/autobackend.py +++ b/ultralytics/nn/autobackend.py @@ -59,21 +59,22 @@ class AutoBackend(nn.Module): range of formats, each with specific naming conventions as outlined below: Supported Formats and Naming Conventions: - | Format | File Suffix | - |-----------------------|------------------| - | PyTorch | *.pt | - | TorchScript | *.torchscript | - | ONNX Runtime | *.onnx | - | ONNX OpenCV DNN | *.onnx (dnn=True)| - | OpenVINO | *openvino_model/ 
| - | CoreML | *.mlpackage | - | TensorRT | *.engine | - | TensorFlow SavedModel | *_saved_model | - | TensorFlow GraphDef | *.pb | - | TensorFlow Lite | *.tflite | - | TensorFlow Edge TPU | *_edgetpu.tflite | - | PaddlePaddle | *_paddle_model | - | NCNN | *_ncnn_model | + | Format | File Suffix | + |-----------------------|-------------------| + | PyTorch | *.pt | + | TorchScript | *.torchscript | + | ONNX Runtime | *.onnx | + | ONNX OpenCV DNN | *.onnx (dnn=True) | + | OpenVINO | *openvino_model/ | + | CoreML | *.mlpackage | + | TensorRT | *.engine | + | TensorFlow SavedModel | *_saved_model/ | + | TensorFlow GraphDef | *.pb | + | TensorFlow Lite | *.tflite | + | TensorFlow Edge TPU | *_edgetpu.tflite | + | PaddlePaddle | *_paddle_model/ | + | MNN | *.mnn | + | NCNN | *_ncnn_model/ | This class offers dynamic backend switching capabilities based on the input model format, making it easier to deploy models across various platforms. @@ -120,6 +121,7 @@ class AutoBackend(nn.Module): edgetpu, tfjs, paddle, + mnn, ncnn, triton, ) = self._model_type(w) @@ -403,6 +405,26 @@ class AutoBackend(nn.Module): output_names = predictor.get_output_names() metadata = w.parents[1] / "metadata.yaml" + # MNN + elif mnn: + LOGGER.info(f"Loading {w} for MNN inference...") + check_requirements("MNN") # requires MNN + import os + + import MNN + + config = {} + config["precision"] = "low" + config["backend"] = "CPU" + config["numThread"] = (os.cpu_count() + 1) // 2 + rt = MNN.nn.create_runtime_manager((config,)) + net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True) + + def torch_to_mnn(x): + return MNN.expr.const(x.data_ptr(), x.shape) + + metadata = json.loads(net.get_info()["bizCode"]) + # NCNN elif ncnn: LOGGER.info(f"Loading {w} for NCNN inference...") @@ -590,6 +612,12 @@ class AutoBackend(nn.Module): self.predictor.run() y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names] + # MNN + elif self.mnn: + input_var = self.torch_to_mnn(im) + output_var = self.net.onForward([input_var]) + y = [x.read() for x in output_var] + # NCNN elif self.ncnn: mat_in = self.pyncnn.Mat(im[0].cpu().numpy()) diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py index 653f48d3a..3ddd934db 100644 --- a/ultralytics/utils/benchmarks.py +++ b/ultralytics/utils/benchmarks.py @@ -21,6 +21,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite TensorFlow.js | `tfjs` | yolov8n_web_model/ PaddlePaddle | `paddle` | yolov8n_paddle_model/ +MNN | `mnn` | yolov8n.mnn NCNN | `ncnn` | yolov8n_ncnn_model/ """ @@ -111,8 +112,8 @@ def benchmark( assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet" assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet" assert LINUX or MACOS, "Windows Paddle exports not supported yet" - if i in {12}: # NCNN - assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet" + if i in {12, 13}: # MNN, NCNN + assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN, NCNN exports not supported yet" if "cpu" in device.type: assert cpu, "inference not supported on CPU" if "cuda" in device.type: @@ -132,7 +133,7 @@ def benchmark( assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported" assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML - if i in {12}: + 
if i in {13}:  # NCNN
            assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
        exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)