`ultralytics 8.3.25` Alibaba MNN export and predict support (#16802)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Francesco Mattioli <Francesco.mttl@gmail.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/17284/head v8.3.25
王召德 4 weeks ago committed by GitHub
parent 11b4194344
commit 9c72d94ba4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 1
      .gitignore
  2. 35
      docs/en/guides/model-deployment-options.md
  3. 6
      docs/en/integrations/index.md
  4. 342
      docs/en/integrations/mnn.md
  5. 1
      docs/en/macros/export-table.md
  6. 3
      docs/mkdocs_github_authors.yaml
  7. 3
      mkdocs.yml
  8. 7
      tests/test_exports.py
  9. 2
      ultralytics/__init__.py
  10. 37
      ultralytics/engine/exporter.py
  11. 1
      ultralytics/engine/predictor.py
  12. 1
      ultralytics/engine/validator.py
  13. 58
      ultralytics/nn/autobackend.py
  14. 7
      ultralytics/utils/benchmarks.py

1
.gitignore vendored

@ -157,6 +157,7 @@ weights/
*.torchscript
*.tflite
*.h5
*.mnn
*_saved_model/
*_web_model/
*_openvino_model/

@ -258,25 +258,30 @@ NCNN is a high-performance neural network inference framework optimized for the
- **Hardware Acceleration**: Tailored for ARM CPUs and GPUs, with specific optimizations for these architectures.
#### MNN
MNN is a highly efficient and lightweight deep learning framework. It supports inference and training of deep learning models and has industry-leading performance for inference and training on-device. In addition, MNN is also used on embedded devices, such as IoT.
## Comparative Analysis of YOLO11 Deployment Options
The following table provides a snapshot of the various deployment options available for YOLO11 models, helping you to assess which may best fit your project needs based on several critical criteria. For an in-depth look at each deployment option's format, please see the [Ultralytics documentation page on export formats](../modes/export.md#export-formats).
| Deployment Option | Performance Benchmarks | Compatibility and Integration | Community Support and Ecosystem | Case Studies | Maintenance and Updates | Security Considerations | Hardware Acceleration |
| ----------------- | ----------------------------------------------- | ---------------------------------------------- | --------------------------------------------- | ------------------------------------------ | ------------------------------------------- | ------------------------------------------------- | ---------------------------------- |
| PyTorch | Good flexibility; may trade off raw performance | Excellent with Python libraries | Extensive resources and community | Research and prototypes | Regular, active development | Dependent on deployment environment | CUDA support for GPU acceleration |
| TorchScript | Better for production than PyTorch | Smooth transition from PyTorch to C++ | Specialized but narrower than PyTorch | Industry where Python is a bottleneck | Consistent updates with PyTorch | Improved security without full Python | Inherits CUDA support from PyTorch |
| ONNX | Variable depending on runtime | High across different frameworks | Broad ecosystem, supported by many orgs | Flexibility across ML frameworks | Regular updates for new operations | Ensure secure conversion and deployment practices | Various hardware optimizations |
| OpenVINO | Optimized for Intel hardware | Best within Intel ecosystem | Solid in computer vision domain | IoT and edge with Intel hardware | Regular updates for Intel hardware | Robust features for sensitive applications | Tailored for Intel hardware |
| TensorRT | Top-tier on NVIDIA GPUs | Best for NVIDIA hardware | Strong network through NVIDIA | Real-time video and image inference | Frequent updates for new GPUs | Emphasis on security | Designed for NVIDIA GPUs |
| CoreML | Optimized for on-device Apple hardware | Exclusive to Apple ecosystem | Strong Apple and developer support | On-device ML on Apple products | Regular Apple updates | Focus on privacy and security | Apple neural engine and GPU |
| TF SavedModel | Scalable in server environments | Wide compatibility in TensorFlow ecosystem | Large support due to TensorFlow popularity | Serving models at scale | Regular updates by Google and community | Robust features for enterprise | Various hardware accelerations |
| TF GraphDef | Stable for static computation graphs | Integrates well with TensorFlow infrastructure | Resources for optimizing static graphs | Scenarios requiring static graphs | Updates alongside TensorFlow core | Established TensorFlow security practices | TensorFlow acceleration options |
| TF Lite | Speed and efficiency on mobile/embedded | Wide range of device support | Robust community, Google backed | Mobile applications with minimal footprint | Latest features for mobile | Secure environment on end-user devices | GPU and DSP among others |
| TF Edge TPU | Optimized for Google's Edge TPU hardware | Exclusive to Edge TPU devices | Growing with Google and third-party resources | IoT devices requiring real-time processing | Improvements for new Edge TPU hardware | Google's robust IoT security | Custom-designed for Google Coral |
| TF.js | Reasonable in-browser performance | High with web technologies | Web and Node.js developers support | Interactive web applications | TensorFlow team and community contributions | Web platform security model | Enhanced with WebGL and other APIs |
| PaddlePaddle | Competitive, easy to use and scalable | Baidu ecosystem, wide application support | Rapidly growing, especially in China | Chinese market and language processing | Focus on Chinese AI applications | Emphasizes data privacy and security | Including Baidu's Kunlun chips |
| NCNN | Optimized for mobile ARM-based devices | Mobile and embedded ARM systems | Niche but active mobile/embedded ML community | Android and ARM systems efficiency | High performance maintenance on ARM | On-device security advantages | ARM CPUs and GPUs optimizations |
| Deployment Option | Performance Benchmarks | Compatibility and Integration | Community Support and Ecosystem | Case Studies | Maintenance and Updates | Security Considerations | Hardware Acceleration |
| ----------------- | ----------------------------------------------- | ---------------------------------------------- | --------------------------------------------- | ------------------------------------------ | ---------------------------------------------- | ------------------------------------------------- | ---------------------------------- |
| PyTorch | Good flexibility; may trade off raw performance | Excellent with Python libraries | Extensive resources and community | Research and prototypes | Regular, active development | Dependent on deployment environment | CUDA support for GPU acceleration |
| TorchScript | Better for production than PyTorch | Smooth transition from PyTorch to C++ | Specialized but narrower than PyTorch | Industry where Python is a bottleneck | Consistent updates with PyTorch | Improved security without full Python | Inherits CUDA support from PyTorch |
| ONNX | Variable depending on runtime | High across different frameworks | Broad ecosystem, supported by many orgs | Flexibility across ML frameworks | Regular updates for new operations | Ensure secure conversion and deployment practices | Various hardware optimizations |
| OpenVINO | Optimized for Intel hardware | Best within Intel ecosystem | Solid in computer vision domain | IoT and edge with Intel hardware | Regular updates for Intel hardware | Robust features for sensitive applications | Tailored for Intel hardware |
| TensorRT | Top-tier on NVIDIA GPUs | Best for NVIDIA hardware | Strong network through NVIDIA | Real-time video and image inference | Frequent updates for new GPUs | Emphasis on security | Designed for NVIDIA GPUs |
| CoreML | Optimized for on-device Apple hardware | Exclusive to Apple ecosystem | Strong Apple and developer support | On-device ML on Apple products | Regular Apple updates | Focus on privacy and security | Apple neural engine and GPU |
| TF SavedModel | Scalable in server environments | Wide compatibility in TensorFlow ecosystem | Large support due to TensorFlow popularity | Serving models at scale | Regular updates by Google and community | Robust features for enterprise | Various hardware accelerations |
| TF GraphDef | Stable for static computation graphs | Integrates well with TensorFlow infrastructure | Resources for optimizing static graphs | Scenarios requiring static graphs | Updates alongside TensorFlow core | Established TensorFlow security practices | TensorFlow acceleration options |
| TF Lite | Speed and efficiency on mobile/embedded | Wide range of device support | Robust community, Google backed | Mobile applications with minimal footprint | Latest features for mobile | Secure environment on end-user devices | GPU and DSP among others |
| TF Edge TPU | Optimized for Google's Edge TPU hardware | Exclusive to Edge TPU devices | Growing with Google and third-party resources | IoT devices requiring real-time processing | Improvements for new Edge TPU hardware | Google's robust IoT security | Custom-designed for Google Coral |
| TF.js | Reasonable in-browser performance | High with web technologies | Web and Node.js developers support | Interactive web applications | TensorFlow team and community contributions | Web platform security model | Enhanced with WebGL and other APIs |
| PaddlePaddle | Competitive, easy to use and scalable | Baidu ecosystem, wide application support | Rapidly growing, especially in China | Chinese market and language processing | Focus on Chinese AI applications | Emphasizes data privacy and security | Including Baidu's Kunlun chips |
| MNN | High-performance for mobile devices. | Mobile and embedded ARM systems and X86-64 CPU | Mobile/embedded ML community | Moblile systems efficiency | High performance maintenance on Mobile Devices | On-device security advantages | ARM CPUs and GPUs optimizations |
| NCNN | Optimized for mobile ARM-based devices | Mobile and embedded ARM systems | Niche but active mobile/embedded ML community | Android and ARM systems efficiency | High performance maintenance on ARM | On-device security advantages | ARM CPUs and GPUs optimizations |
This comparative analysis gives you a high-level overview. For deployment, it's essential to consider the specific requirements and constraints of your project, and consult the detailed documentation and resources available for each option.

@ -57,6 +57,8 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of
- [Weights & Biases (W&B)](weights-biases.md): Monitor experiments, visualize metrics, and foster reproducibility and collaboration on Ultralytics projects.
- [VS Code](vscode.md): An extension for VS Code that provides code snippets for accelerating development workflows with Ultralytics and also for anyone looking for examples to help learn or get started with Ultralytics.
## Deployment Integrations
- [CoreML](coreml.md): CoreML, developed by [Apple](https://www.apple.com/), is a framework designed for efficiently integrating machine learning models into applications across iOS, macOS, watchOS, and tvOS, using Apple's hardware for effective and secure [model deployment](https://www.ultralytics.com/glossary/model-deployment).
@ -65,6 +67,8 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of
- [NCNN](ncnn.md): Developed by [Tencent](http://www.tencent.com/), NCNN is an efficient [neural network](https://www.ultralytics.com/glossary/neural-network-nn) inference framework tailored for mobile devices. It enables direct deployment of AI models into apps, optimizing performance across various mobile platforms.
- [MNN](mnn.md): Developed by [Alibaba](https://www.alibabagroup.com/), MNN is a highly efficient and lightweight deep learning framework. It supports inference and training of deep learning models and has industry-leading performance for inference and training on-device.
- [Neural Magic](neural-magic.md): Leverage Quantization Aware Training (QAT) and pruning techniques to optimize Ultralytics models for superior performance and leaner size.
- [ONNX](onnx.md): An open-source format created by [Microsoft](https://www.microsoft.com/) for facilitating the transfer of AI models between various frameworks, enhancing the versatility and deployment flexibility of Ultralytics models.
@ -87,8 +91,6 @@ Welcome to the Ultralytics Integrations page! This page provides an overview of
- [TorchScript](torchscript.md): Developed as part of the [PyTorch](https://pytorch.org/) framework, TorchScript enables efficient execution and deployment of machine learning models in various production environments without the need for Python dependencies.
- [VS Code](vscode.md): An extension for VS Code that provides code snippets for accelerating development workflows with Ultralytics and also for anyone looking for examples to help learn or get started with Ultralytics.
### Export Formats
We also support a variety of model export formats for deployment in different environments. Here are the available formats:

@ -0,0 +1,342 @@
---
comments: true
description: Optimize YOLO11 models for mobile and embedded devices by exporting to MNN format.
keywords: Ultralytics, YOLO11, MNN, model export, machine learning, deployment, mobile, embedded systems, deep learning, AI models
---
# MNN Export for YOLO11 Models and Deploy
## MNN
<p align="center">
<img width="100%" src="https://mnn-docs.readthedocs.io/en/latest/_images/architecture.png" alt="MNN architecture">
</p>
[MNN](https://github.com/alibaba/MNN) is a highly efficient and lightweight deep learning framework. It supports inference and training of deep learning models and has industry-leading performance for inference and training on-device. At present, MNN has been integrated into more than 30 apps of Alibaba Inc, such as Taobao, Tmall, Youku, DingTalk, Xianyu, etc., covering more than 70 usage scenarios such as live broadcast, short video capture, search recommendation, product searching by image, interactive marketing, equity distribution, security risk control. In addition, MNN is also used on embedded devices, such as IoT.
## Export to MNN: Converting Your YOLO11 Model
You can expand model compatibility and deployment flexibility by converting YOLO11 models to MNN format.
### Installation
To install the required packages, run:
!!! tip "Installation"
=== "CLI"
```bash
# Install the required package for YOLO11 and MNN
pip install ultralytics
pip install MNN
```
### Usage
Before diving into the usage instructions, it's important to note that while all [Ultralytics YOLO11 models](../models/index.md) are available for exporting, you can ensure that the model you select supports export functionality [here](../modes/export.md).
!!! example "Usage"
=== "Python"
```python
from ultralytics import YOLO
# Load the YOLO11 model
model = YOLO("yolo11n.pt")
# Export the model to MNN format
model.export(format="mnn") # creates 'yolo11n.mnn'
# Load the exported MNN model
mnn_model = YOLO("yolo11n.mnn")
# Run inference
results = mnn_model("https://ultralytics.com/images/bus.jpg")
```
=== "CLI"
```bash
# Export a YOLO11n PyTorch model to MNN format
yolo export model=yolo11n.pt format=mnn # creates 'yolo11n.mnn'
# Run inference with the exported model
yolo predict model='yolo11n.mnn' source='https://ultralytics.com/images/bus.jpg'
```
For more details about supported export options, visit the [Ultralytics documentation page on deployment options](../guides/model-deployment-options.md).
### MNN-Only Inference
A function that relies solely on MNN for YOLO11 inference and preprocessing is implemented, providing both Python and C++ versions for easy deployment in any scenario.
!!! example "MNN"
=== "Python"
```python
import argparse
import MNN
import MNN.cv as cv2
import MNN.numpy as np
def inference(model, img, precision, backend, thread):
config = {}
config["precision"] = precision
config["backend"] = backend
config["numThread"] = thread
rt = MNN.nn.create_runtime_manager((config,))
# net = MNN.nn.load_module_from_file(model, ['images'], ['output0'], runtime_manager=rt)
net = MNN.nn.load_module_from_file(model, [], [], runtime_manager=rt)
original_image = cv2.imread(img)
ih, iw, _ = original_image.shape
length = max((ih, iw))
scale = length / 640
image = np.pad(original_image, [[0, length - ih], [0, length - iw], [0, 0]], "constant")
image = cv2.resize(
image, (640, 640), 0.0, 0.0, cv2.INTER_LINEAR, -1, [0.0, 0.0, 0.0], [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0]
)
input_var = np.expand_dims(image, 0)
input_var = MNN.expr.convert(input_var, MNN.expr.NC4HW4)
output_var = net.forward(input_var)
output_var = MNN.expr.convert(output_var, MNN.expr.NCHW)
output_var = output_var.squeeze()
# output_var shape: [84, 8400]; 84 means: [cx, cy, w, h, prob * 80]
cx = output_var[0]
cy = output_var[1]
w = output_var[2]
h = output_var[3]
probs = output_var[4:]
# [cx, cy, w, h] -> [y0, x0, y1, x1]
x0 = cx - w * 0.5
y0 = cy - h * 0.5
x1 = cx + w * 0.5
y1 = cy + h * 0.5
boxes = np.stack([x0, y0, x1, y1], axis=1)
# get max prob and idx
scores = np.max(probs, 0)
class_ids = np.argmax(probs, 0)
result_ids = MNN.expr.nms(boxes, scores, 100, 0.45, 0.25)
print(result_ids.shape)
# nms result box, score, ids
result_boxes = boxes[result_ids]
result_scores = scores[result_ids]
result_class_ids = class_ids[result_ids]
for i in range(len(result_boxes)):
x0, y0, x1, y1 = result_boxes[i].read_as_tuple()
y0 = int(y0 * scale)
y1 = int(y1 * scale)
x0 = int(x0 * scale)
x1 = int(x1 * scale)
print(result_class_ids[i])
cv2.rectangle(original_image, (x0, y0), (x1, y1), (0, 0, 255), 2)
cv2.imwrite("res.jpg", original_image)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="the yolo11 model path")
parser.add_argument("--img", type=str, required=True, help="the input image path")
parser.add_argument("--precision", type=str, default="normal", help="inference precision: normal, low, high, lowBF")
parser.add_argument(
"--backend",
type=str,
default="CPU",
help="inference backend: CPU, OPENCL, OPENGL, NN, VULKAN, METAL, TRT, CUDA, HIAI",
)
parser.add_argument("--thread", type=int, default=4, help="inference using thread: int")
args = parser.parse_args()
inference(args.model, args.img, args.precision, args.backend, args.thread)
```
=== "CPP"
```cpp
#include <stdio.h>
#include <MNN/ImageProcess.hpp>
#include <MNN/expr/Module.hpp>
#include <MNN/expr/Executor.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include <MNN/expr/Executor.hpp>
#include <cv/cv.hpp>
using namespace MNN;
using namespace MNN::Express;
using namespace MNN::CV;
int main(int argc, const char* argv[]) {
if (argc < 3) {
MNN_PRINT("Usage: ./yolo11_demo.out model.mnn input.jpg [forwardType] [precision] [thread]\n");
return 0;
}
int thread = 4;
int precision = 0;
int forwardType = MNN_FORWARD_CPU;
if (argc >= 4) {
forwardType = atoi(argv[3]);
}
if (argc >= 5) {
precision = atoi(argv[4]);
}
if (argc >= 6) {
thread = atoi(argv[5]);
}
MNN::ScheduleConfig sConfig;
sConfig.type = static_cast<MNNForwardType>(forwardType);
sConfig.numThread = thread;
BackendConfig bConfig;
bConfig.precision = static_cast<BackendConfig::PrecisionMode>(precision);
sConfig.backendConfig = &bConfig;
std::shared_ptr<Executor::RuntimeManager> rtmgr = std::shared_ptr<Executor::RuntimeManager>(Executor::RuntimeManager::createRuntimeManager(sConfig));
if(rtmgr == nullptr) {
MNN_ERROR("Empty RuntimeManger\n");
return 0;
}
rtmgr->setCache(".cachefile");
std::shared_ptr<Module> net(Module::load(std::vector<std::string>{}, std::vector<std::string>{}, argv[1], rtmgr));
auto original_image = imread(argv[2]);
auto dims = original_image->getInfo()->dim;
int ih = dims[0];
int iw = dims[1];
int len = ih > iw ? ih : iw;
float scale = len / 640.0;
std::vector<int> padvals { 0, len - ih, 0, len - iw, 0, 0 };
auto pads = _Const(static_cast<void*>(padvals.data()), {3, 2}, NCHW, halide_type_of<int>());
auto image = _Pad(original_image, pads, CONSTANT);
image = resize(image, Size(640, 640), 0, 0, INTER_LINEAR, -1, {0., 0., 0.}, {1./255., 1./255., 1./255.});
auto input = _Unsqueeze(image, {0});
input = _Convert(input, NC4HW4);
auto outputs = net->onForward({input});
auto output = _Convert(outputs[0], NCHW);
output = _Squeeze(output);
// output shape: [84, 8400]; 84 means: [cx, cy, w, h, prob * 80]
auto cx = _Gather(output, _Scalar<int>(0));
auto cy = _Gather(output, _Scalar<int>(1));
auto w = _Gather(output, _Scalar<int>(2));
auto h = _Gather(output, _Scalar<int>(3));
std::vector<int> startvals { 4, 0 };
auto start = _Const(static_cast<void*>(startvals.data()), {2}, NCHW, halide_type_of<int>());
std::vector<int> sizevals { -1, -1 };
auto size = _Const(static_cast<void*>(sizevals.data()), {2}, NCHW, halide_type_of<int>());
auto probs = _Slice(output, start, size);
// [cx, cy, w, h] -> [y0, x0, y1, x1]
auto x0 = cx - w * _Const(0.5);
auto y0 = cy - h * _Const(0.5);
auto x1 = cx + w * _Const(0.5);
auto y1 = cy + h * _Const(0.5);
auto boxes = _Stack({x0, y0, x1, y1}, 1);
auto scores = _ReduceMax(probs, {0});
auto ids = _ArgMax(probs, 0);
auto result_ids = _Nms(boxes, scores, 100, 0.45, 0.25);
auto result_ptr = result_ids->readMap<int>();
auto box_ptr = boxes->readMap<float>();
auto ids_ptr = ids->readMap<int>();
auto score_ptr = scores->readMap<float>();
for (int i = 0; i < 100; i++) {
auto idx = result_ptr[i];
if (idx < 0) break;
auto x0 = box_ptr[idx * 4 + 0] * scale;
auto y0 = box_ptr[idx * 4 + 1] * scale;
auto x1 = box_ptr[idx * 4 + 2] * scale;
auto y1 = box_ptr[idx * 4 + 3] * scale;
auto class_idx = ids_ptr[idx];
auto score = score_ptr[idx];
rectangle(original_image, {x0, y0}, {x1, y1}, {0, 0, 255}, 2);
}
if (imwrite("res.jpg", original_image)) {
MNN_PRINT("result image write to `res.jpg`.\n");
}
rtmgr->updateCache();
return 0;
}
```
## Summary
In this guide, we introduce how to export the Ultralytics YOLO11 model to MNN and use MNN for inference.
For more usage, please refer to the [MNN documentation](https://mnn-docs.readthedocs.io/en/latest).
## FAQ
### How do I export Ultralytics YOLO11 models to MNN format?
To export your Ultralytics YOLO11 model to MNN format, follow these steps:
!!! example "Export"
=== "Python"
```python
from ultralytics import YOLO
# Load the YOLO11 model
model = YOLO("yolo11n.pt")
# Export to MNN format
model.export(format="mnn") # creates 'yolo11n.mnn' with fp32 weight
model.export(format="mnn", half=True) # creates 'yolo11n.mnn' with fp16 weight
model.export(format="mnn", int8=True) # creates 'yolo11n.mnn' with int8 weight
```
=== "CLI"
```bash
yolo export model=yolo11n.pt format=mnn # creates 'yolo11n.mnn' with fp32 weight
yolo export model=yolo11n.pt format=mnn half=True # creates 'yolo11n.mnn' with fp16 weight
yolo export model=yolo11n.pt format=mnn int8=True # creates 'yolo11n.mnn' with int8 weight
```
For detailed export options, check the [Export](../modes/export.md) page in the documentation.
### How do I predict with an exported YOLO11 MNN model?
To predict with an exported YOLO11 MNN model, use the `predict` function from the YOLO class.
!!! example "Predict"
=== "Python"
```python
from ultralytics import YOLO
# Load the YOLO11 MNN model
model = YOLO("yolo11n.mnn")
# Export to MNN format
results = mnn_model("https://ultralytics.com/images/bus.jpg") # predict with `fp32`
results = mnn_model("https://ultralytics.com/images/bus.jpg", half=True) # predict with `fp16` if device support
for result in results:
result.show() # display to screen
result.save(filename="result.jpg") # save to disk
```
=== "CLI"
```bash
yolo predict model='yolo11n.mnn' source='https://ultralytics.com/images/bus.jpg' # predict with `fp32`
yolo predict model='yolo11n.mnn' source='https://ultralytics.com/images/bus.jpg' --half=True # predict with `fp16` if device support
```
### What platforms are supported for MNN?
MNN is versatile and supports various platforms:
- **Mobile**: Android, iOS, Harmony.
- **Embedded Systems and IoT Devices**: Devices like Raspberry Pi and NVIDIA Jetson.
- **Desktop and Servers**: Linux, Windows, and macOS.
### How can I deploy Ultralytics YOLO11 MNN models on Mobile Devices?
To deploy your YOLO11 models on Mobile devices:
1. **Build for Android**: Follow the [MNN Android](https://github.com/alibaba/MNN/tree/master/project/android).
2. **Build for iOS**: Follow the [MNN iOS](https://github.com/alibaba/MNN/tree/master/project/ios).
3. **Build for Harmony**: Follow the [MNN Harmony](https://github.com/alibaba/MNN/tree/master/project/harmony).

@ -12,4 +12,5 @@
| [TF Edge TPU](../integrations/edge-tpu.md) | `edgetpu` | `{{ model_name or "yolo11n" }}_edgetpu.tflite` | ✅ | `imgsz` |
| [TF.js](../integrations/tfjs.md) | `tfjs` | `{{ model_name or "yolo11n" }}_web_model/` | ✅ | `imgsz`, `half`, `int8`, `batch` |
| [PaddlePaddle](../integrations/paddlepaddle.md) | `paddle` | `{{ model_name or "yolo11n" }}_paddle_model/` | ✅ | `imgsz`, `batch` |
| [MNN](../integrations/mnn.md) | `mnn` | `{{ model_name or "yolo11n" }}.mnn` | ✅ | `imgsz`, `batch`, `int8`, `half` |
| [NCNN](../integrations/ncnn.md) | `ncnn` | `{{ model_name or "yolo11n" }}_ncnn_model/` | ✅ | `imgsz`, `half`, `batch` |

@ -154,3 +154,6 @@ web@ultralytics.com:
xinwang614@gmail.com:
avatar: https://avatars.githubusercontent.com/u/17264618?v=4
username: GreatV
zhaode.wzd@alibaba-inc.com:
avatar: https://avatars.githubusercontent.com/u/8401806?v=4
username: ZhaodeWang

@ -398,11 +398,12 @@ nav:
- JupyterLab: integrations/jupyterlab.md
- Kaggle: integrations/kaggle.md
- MLflow: integrations/mlflow.md
- NCNN: integrations/ncnn.md
- Neural Magic: integrations/neural-magic.md
- ONNX: integrations/onnx.md
- OpenVINO: integrations/openvino.md
- PaddlePaddle: integrations/paddlepaddle.md
- MNN: integrations/mnn.md
- NCNN: integrations/ncnn.md
- Paperspace Gradient: integrations/paperspace.md
- Ray Tune: integrations/ray-tune.md
- Roboflow: integrations/roboflow.md

@ -197,3 +197,10 @@ def test_export_ncnn():
"""Test YOLO exports to NCNN format."""
file = YOLO(MODEL).export(format="ncnn", imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference
@pytest.mark.slow
def test_export_mnn():
"""Test YOLO exports to MNN format."""
file = YOLO(MODEL).export(format="mnn", imgsz=32)
YOLO(file)(SOURCE, imgsz=32) # exported model inference

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.24"
__version__ = "8.3.25"
import os

@ -16,6 +16,7 @@ TensorFlow Lite | `tflite` | yolo11n.tflite
TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolo11n_web_model/
PaddlePaddle | `paddle` | yolo11n_paddle_model/
MNN | `mnn` | yolo11n.mnn
NCNN | `ncnn` | yolo11n_ncnn_model/
Requirements:
@ -41,6 +42,7 @@ Inference:
yolo11n.tflite # TensorFlow Lite
yolo11n_edgetpu.tflite # TensorFlow Edge TPU
yolo11n_paddle_model # PaddlePaddle
yolo11n.mnn # MNN
yolo11n_ncnn_model # NCNN
TensorFlow.js:
@ -109,6 +111,7 @@ def export_formats():
["TensorFlow Edge TPU", "edgetpu", "_edgetpu.tflite", True, False],
["TensorFlow.js", "tfjs", "_web_model", True, False],
["PaddlePaddle", "paddle", "_paddle_model", True, True],
["MNN", "mnn", ".mnn", True, True],
["NCNN", "ncnn", "_ncnn_model", True, True],
]
return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x)))
@ -190,7 +193,9 @@ class Exporter:
flags = [x == fmt for x in fmts]
if sum(flags) != 1:
raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn = (
flags # export booleans
)
is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
# Device
@ -333,8 +338,10 @@ class Exporter:
f[9], _ = self.export_tfjs()
if paddle: # PaddlePaddle
f[10], _ = self.export_paddle()
if mnn: # MNN
f[11], _ = self.export_mnn()
if ncnn: # NCNN
f[11], _ = self.export_ncnn()
f[12], _ = self.export_ncnn()
# Finish
f = [str(x) for x in f if x] # filter out '' and None
@ -541,6 +548,32 @@ class Exporter:
yaml_save(Path(f) / "metadata.yaml", self.metadata) # add metadata.yaml
return f, None
@try_export
def export_mnn(self, prefix=colorstr("MNN:")):
"""YOLOv8 MNN export using MNN https://github.com/alibaba/MNN."""
f_onnx, _ = self.export_onnx() # get onnx model first
check_requirements("MNN>=2.9.6")
import MNN # noqa
from MNN.tools import mnnconvert
# Setup and checks
LOGGER.info(f"\n{prefix} starting export with MNN {MNN.version()}...")
assert Path(f_onnx).exists(), f"failed to export ONNX file: {f_onnx}"
f = str(self.file.with_suffix(".mnn")) # MNN model file
args = ["", "-f", "ONNX", "--modelFile", f_onnx, "--MNNModel", f, "--bizCode", json.dumps(self.metadata)]
if self.args.int8:
args.append("--weightQuantBits")
args.append("8")
if self.args.half:
args.append("--fp16")
mnnconvert.convert(args)
# remove scratch file for model convert optimize
convert_scratch = Path(self.file.parent / ".__convert_external_data.bin")
if convert_scratch.exists():
convert_scratch.unlink()
return f, None
@try_export
def export_ncnn(self, prefix=colorstr("NCNN:")):
"""YOLO NCNN export using PNNX https://github.com/pnnx/pnnx."""

@ -26,6 +26,7 @@ Usage - formats:
yolov8n.tflite # TensorFlow Lite
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
yolov8n_paddle_model # PaddlePaddle
yolov8n.mnn # MNN
yolov8n_ncnn_model # NCNN
"""

@ -17,6 +17,7 @@ Usage - formats:
yolov8n.tflite # TensorFlow Lite
yolov8n_edgetpu.tflite # TensorFlow Edge TPU
yolov8n_paddle_model # PaddlePaddle
yolov8n.mnn # MNN
yolov8n_ncnn_model # NCNN
"""

@ -59,21 +59,22 @@ class AutoBackend(nn.Module):
range of formats, each with specific naming conventions as outlined below:
Supported Formats and Naming Conventions:
| Format | File Suffix |
|-----------------------|------------------|
| PyTorch | *.pt |
| TorchScript | *.torchscript |
| ONNX Runtime | *.onnx |
| ONNX OpenCV DNN | *.onnx (dnn=True)|
| OpenVINO | *openvino_model/ |
| CoreML | *.mlpackage |
| TensorRT | *.engine |
| TensorFlow SavedModel | *_saved_model |
| TensorFlow GraphDef | *.pb |
| TensorFlow Lite | *.tflite |
| TensorFlow Edge TPU | *_edgetpu.tflite |
| PaddlePaddle | *_paddle_model |
| NCNN | *_ncnn_model |
| Format | File Suffix |
|-----------------------|-------------------|
| PyTorch | *.pt |
| TorchScript | *.torchscript |
| ONNX Runtime | *.onnx |
| ONNX OpenCV DNN | *.onnx (dnn=True) |
| OpenVINO | *openvino_model/ |
| CoreML | *.mlpackage |
| TensorRT | *.engine |
| TensorFlow SavedModel | *_saved_model/ |
| TensorFlow GraphDef | *.pb |
| TensorFlow Lite | *.tflite |
| TensorFlow Edge TPU | *_edgetpu.tflite |
| PaddlePaddle | *_paddle_model/ |
| MNN | *.mnn |
| NCNN | *_ncnn_model/ |
This class offers dynamic backend switching capabilities based on the input model format, making it easier to deploy
models across various platforms.
@ -120,6 +121,7 @@ class AutoBackend(nn.Module):
edgetpu,
tfjs,
paddle,
mnn,
ncnn,
triton,
) = self._model_type(w)
@ -403,6 +405,26 @@ class AutoBackend(nn.Module):
output_names = predictor.get_output_names()
metadata = w.parents[1] / "metadata.yaml"
# MNN
elif mnn:
LOGGER.info(f"Loading {w} for MNN inference...")
check_requirements("MNN") # requires MNN
import os
import MNN
config = {}
config["precision"] = "low"
config["backend"] = "CPU"
config["numThread"] = (os.cpu_count() + 1) // 2
rt = MNN.nn.create_runtime_manager((config,))
net = MNN.nn.load_module_from_file(w, [], [], runtime_manager=rt, rearrange=True)
def torch_to_mnn(x):
return MNN.expr.const(x.data_ptr(), x.shape)
metadata = json.loads(net.get_info()["bizCode"])
# NCNN
elif ncnn:
LOGGER.info(f"Loading {w} for NCNN inference...")
@ -590,6 +612,12 @@ class AutoBackend(nn.Module):
self.predictor.run()
y = [self.predictor.get_output_handle(x).copy_to_cpu() for x in self.output_names]
# MNN
elif self.mnn:
input_var = self.torch_to_mnn(im)
output_var = self.net.onForward([input_var])
y = [x.read() for x in output_var]
# NCNN
elif self.ncnn:
mat_in = self.pyncnn.Mat(im[0].cpu().numpy())

@ -21,6 +21,7 @@ TensorFlow Lite | `tflite` | yolov8n.tflite
TensorFlow Edge TPU | `edgetpu` | yolov8n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolov8n_web_model/
PaddlePaddle | `paddle` | yolov8n_paddle_model/
MNN | `mnn` | yolov8n.mnn
NCNN | `ncnn` | yolov8n_ncnn_model/
"""
@ -111,8 +112,8 @@ def benchmark(
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
assert LINUX or MACOS, "Windows Paddle exports not supported yet"
if i in {12}: # NCNN
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
if i in {12, 13}: # MNN, NCNN
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN, NCNN exports not supported yet"
if "cpu" in device.type:
assert cpu, "inference not supported on CPU"
if "cuda" in device.type:
@ -132,7 +133,7 @@ def benchmark(
assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported
assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
if i in {12}:
if i in {13}:
assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)

Loading…
Cancel
Save