@@ -18,6 +18,7 @@ TensorFlow.js | `tfjs` | yolo11n_web_model/
     PaddlePaddle            | `paddle`                  | yolo11n_paddle_model/
     MNN                     | `mnn`                     | yolo11n.mnn
     NCNN                    | `ncnn`                    | yolo11n_ncnn_model/
+    IMX                     | `imx`                     | yolo11n_imx_model/
 
 Requirements:
     $ pip install "ultralytics[export]"
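
The new `imx` row adds the user-facing format name. A minimal export call, assuming the standard Ultralytics Python API (the table's `Argument` column is what `format=` accepts):

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.export(format="imx")  # writes yolo11n_imx_model/
```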
@@ -44,6 +45,7 @@ Inference:
         yolo11n_paddle_model       # PaddlePaddle
         yolo11n.mnn                # MNN
         yolo11n_ncnn_model         # NCNN
+        yolo11n_imx_model          # IMX
 
 TensorFlow.js:
     $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
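
The exported directory loads back like any other Ultralytics model. A sketch under the same assumptions as above:

```python
from ultralytics import YOLO

model = YOLO("yolo11n_imx_model")  # load the exported IMX model
results = model("https://ultralytics.com/images/bus.jpg")
```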
@@ -94,7 +96,7 @@ from ultralytics.utils.checks import check_imgsz, check_is_path_safe, check_requ
 from ultralytics.utils.downloads import attempt_download_asset, get_github_assets, safe_download
 from ultralytics.utils.files import file_size, spaces_in_path
 from ultralytics.utils.ops import Profile
-from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device, smart_inference_mode
+from ultralytics.utils.torch_utils import TORCH_1_13, get_latest_opset, select_device
 
 
 def export_formats():
@@ -114,6 +116,7 @@ def export_formats():
         ["PaddlePaddle", "paddle", "_paddle_model", True, True],
         ["MNN", "mnn", ".mnn", True, True],
         ["NCNN", "ncnn", "_ncnn_model", True, True],
+        ["IMX", "imx", "_imx_model", True, True],
     ]
     return dict(zip(["Format", "Argument", "Suffix", "CPU", "GPU"], zip(*x)))
 
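For orientation, a sketch of what the registry yields and how `__call__` turns it into per-format booleans (columns abbreviated; illustrative, not the exact table):

```python
fmts = export_formats()
# e.g. fmts["Argument"] -> ("torchscript", "onnx", ..., "ncnn", "imx")
flags = [arg == "imx" for arg in fmts["Argument"]]  # exactly one True for a valid format
```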
@@ -171,7 +174,6 @@ class Exporter:
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
         callbacks.add_integration_callbacks(self)
 
-    @smart_inference_mode()
     def __call__(self, model=None) -> str:
         """Returns list of exported files/dirs after running callbacks."""
         self.run_callbacks("on_export_start")
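
Dropping the decorator (and its import above) is presumably what allows the gradient-based quantization path in the new `export_imx` to run: tensors created under `torch.inference_mode` cannot participate in autograd. A standalone illustration of that PyTorch behavior:

```python
import torch

x = torch.ones(3, requires_grad=True)
with torch.inference_mode():
    y = x * 2  # autograd is disabled: y is an inference tensor with no grad_fn
print(y.requires_grad)  # False -> no gradient can flow back to x through y
```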
@@ -194,9 +196,22 @@
         flags = [x == fmt for x in fmts]
         if sum(flags) != 1:
             raise ValueError(f"Invalid export format='{fmt}'. Valid formats are {fmts}")
-        jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, mnn, ncnn = (
-            flags  # export booleans
-        )
+        (
+            jit,
+            onnx,
+            xml,
+            engine,
+            coreml,
+            saved_model,
+            pb,
+            tflite,
+            edgetpu,
+            tfjs,
+            paddle,
+            mnn,
+            ncnn,
+            imx,
+        ) = flags  # export booleans
         is_tf_format = any((saved_model, pb, tflite, edgetpu, tfjs))
 
         # Device
@@ -210,6 +225,9 @@
         self.device = select_device("cpu" if self.args.device is None else self.args.device)
 
         # Checks
+        if imx and not self.args.int8:
+            LOGGER.warning("WARNING ⚠️ IMX only supports int8 export, setting int8=True.")
+            self.args.int8 = True
         if not hasattr(model, "names"):
             model.names = default_class_names()
         model.names = check_class_names(model.names)
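
Because the IMX tool-chain is int8-only, an incompatible flag is coerced rather than rejected. The resulting behavior, sketched:

```python
model.export(format="imx")             # logs the warning, proceeds with int8=True
model.export(format="imx", int8=True)  # equivalent, no warning
```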
@@ -249,6 +267,7 @@
             )
         if mnn and (IS_RASPBERRYPI or IS_JETSON):
             raise SystemError("MNN export not supported on Raspberry Pi and NVIDIA Jetson")
+
         # Input
         im = torch.zeros(self.args.batch, 3, *self.imgsz).to(self.device)
         file = Path(
@@ -264,6 +283,11 @@
         model.eval()
         model.float()
         model = model.fuse()
+
+        if imx:
+            from ultralytics.utils.torch_utils import FXModel
+
+            model = FXModel(model)
         for m in model.modules():
             if isinstance(m, (Detect, RTDETRDecoder)):  # includes all Detect subclasses like Segment, Pose, OBB
                 m.dynamic = self.args.dynamic
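
`FXModel` is added to `ultralytics.utils.torch_utils` in this same PR; the wrap presumably exists to make the network `torch.fx`-traceable so the Model Compression Toolkit can rewrite its graph. A generic illustration of symbolic tracing, independent of this codebase:

```python
import torch

net = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
gm = torch.fx.symbolic_trace(net)  # GraphModule exposing an editable graph
print(gm.graph)  # linear/relu nodes that a quantizer can inspect and rewrite
```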
@@ -273,6 +297,15 @@
             elif isinstance(m, C2f) and not is_tf_format:
                 # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
                 m.forward = m.forward_split
+            if isinstance(m, Detect) and imx:
+                from ultralytics.utils.tal import make_anchors
+
+                m.anchors, m.strides = (
+                    x.transpose(0, 1)
+                    for x in make_anchors(
+                        torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
+                    )
+                )
 
         y = None
         for _ in range(2):
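
Precomputing `anchors`/`strides` here bakes the detection grid into the exported graph, which suits a fixed-shape accelerator target: the grid never changes at runtime. Assuming the default 640x640 input and the usual YOLO strides of 8/16/32, the baked-in grid size works out as:

```python
imgsz, strides = 640, (8, 16, 32)
print(sum((imgsz // s) ** 2 for s in strides))  # 8400 = 80*80 + 40*40 + 20*20 anchor points
```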
@@ -347,6 +380,8 @@
             f[11], _ = self.export_mnn()
         if ncnn:  # NCNN
             f[12], _ = self.export_ncnn()
+        if imx:
+            f[13], _ = self.export_imx()
 
         # Finish
         f = [str(x) for x in f if x]  # filter out '' and None
@@ -1068,6 +1103,137 @@
         yaml_save(Path(f) / "metadata.yaml", self.metadata)  # add metadata.yaml
         return f, None
 
+    @try_export
+    def export_imx(self, prefix=colorstr("IMX:")):
+        """YOLO IMX export."""
+        gptq = False
+        assert LINUX, "export only supported on Linux. See https://developer.aitrios.sony-semicon.com/en/raspberrypi-ai-camera/documentation/imx500-converter"
+        if getattr(self.model, "end2end", False):
+            raise ValueError("IMX export is not supported for end2end models.")
+        if "C2f" not in self.model.__str__():
+            raise ValueError("IMX export is only supported for YOLOv8 detection models")
+        check_requirements(("model-compression-toolkit==2.1.1", "sony-custom-layers==0.2.0", "tensorflow==2.12.0"))
+        check_requirements("imx500-converter[pt]==3.14.3")  # Separate requirements for imx500-converter
+
+        import model_compression_toolkit as mct
+        import onnx
+        from sony_custom_layers.pytorch.object_detection.nms import multiclass_nms
+
+        try:
+            out = subprocess.run(
+                ["java", "--version"], check=True, capture_output=True
+            )  # Java 17 is required for imx500-converter
+            if "openjdk 17" not in str(out.stdout):
+                raise FileNotFoundError
+        except FileNotFoundError:
+            subprocess.run(["sudo", "apt", "install", "-y", "openjdk-17-jdk", "openjdk-17-jre"], check=True)
+        def representative_dataset_gen(dataloader=self.get_int8_calibration_dataloader(prefix)):
+            for batch in dataloader:
+                img = batch["img"]
+                img = img / 255.0
+                yield [img]
+
+        tpc = mct.get_target_platform_capabilities(
+            fw_name="pytorch", target_platform_name="imx500", target_platform_version="v1"
+        )
+
+        config = mct.core.CoreConfig(
+            mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
+            quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
+        )
+
+        resource_utilization = mct.core.ResourceUtilization(weights_memory=3146176 * 0.76)
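+        # (3146176 bytes is presumably the IMX500 weight-memory budget; the 0.76 factor
+        # leaves headroom for the converter.)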
+
+        quant_model = (
+            mct.gptq.pytorch_gradient_post_training_quantization(  # Perform Gradient-Based Post Training Quantization
+                model=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                gptq_config=mct.gptq.get_pytorch_gptq_config(n_epochs=1000, use_hessian_based_weights=False),
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )[0]
+            if gptq
+            else mct.ptq.pytorch_post_training_quantization(  # Perform post training quantization
+                in_module=self.model,
+                representative_data_gen=representative_dataset_gen,
+                target_resource_utilization=resource_utilization,
+                core_config=config,
+                target_platform_capabilities=tpc,
+            )[0]
+        )
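+        # gptq is hard-coded False above, so the cheaper data-only PTQ branch runs;
+        # flipping it to True selects the slower gradient-based refinement instead.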
+
+        class NMSWrapper(torch.nn.Module):
+            def __init__(
+                self,
+                model: torch.nn.Module,
+                score_threshold: float = 0.001,
+                iou_threshold: float = 0.7,
+                max_detections: int = 300,
+            ):
+                """
+                Wrapping PyTorch Module with multiclass_nms layer from sony_custom_layers.
+
+                Args:
+                    model (nn.Module): Model instance.
+                    score_threshold (float): Score threshold for non-maximum suppression.
+                    iou_threshold (float): Intersection over union threshold for non-maximum suppression.
+                    max_detections (int): The number of detections to return.
+                """
+                super().__init__()
+                self.model = model
+                self.score_threshold = score_threshold
+                self.iou_threshold = iou_threshold
+                self.max_detections = max_detections
+
+            def forward(self, images):
+                # model inference
+                outputs = self.model(images)
+
+                boxes = outputs[0]
+                scores = outputs[1]
+                nms = multiclass_nms(
+                    boxes=boxes,
+                    scores=scores,
+                    score_threshold=self.score_threshold,
+                    iou_threshold=self.iou_threshold,
+                    max_detections=self.max_detections,
+                )
+                return nms
+
+        quant_model = NMSWrapper(
+            model=quant_model,
+            score_threshold=self.args.conf or 0.001,
+            iou_threshold=self.args.iou,
+            max_detections=self.args.max_det,
+        ).to(self.device)
+
+        f = Path(str(self.file).replace(self.file.suffix, "_imx_model"))
+        f.mkdir(exist_ok=True)
+        onnx_model = f / Path(str(self.file).replace(self.file.suffix, "_imx.onnx"))  # quantized ONNX path
+        mct.exporter.pytorch_export_model(
+            model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
+        )
+
+        model_onnx = onnx.load(onnx_model)  # load onnx model
+        for k, v in self.metadata.items():
+            meta = model_onnx.metadata_props.add()
+            meta.key, meta.value = k, str(v)
+
+        onnx.save(model_onnx, onnx_model)
+
+        subprocess.run(
+            ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
+            check=True,
+        )
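+        # imxconv-pt (from the imx500-converter[pt] package checked above) compiles the
+        # quantized ONNX into the deployable IMX artifacts inside the output directory.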
+
+        # Needed for imx models.
+        with open(f / "labels.txt", "w") as file:
+            file.writelines([f"{name}\n" for _, name in self.model.names.items()])
+
+        return f, None
+
     def _add_tflite_metadata(self, file):
         """Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata."""
         import flatbuffers