@@ -9,7 +9,7 @@ TorchScript | `torchscript` | yolov8n.torchscript
 ONNX | `onnx` | yolov8n.onnx
 OpenVINO | `openvino` | yolov8n_openvino_model/
 TensorRT | `engine` | yolov8n.engine
-CoreML | `coreml` | yolov8n.mlmodel
+CoreML | `coreml` | yolov8n.mlpackage
 TensorFlow SavedModel | `saved_model` | yolov8n_saved_model/
 TensorFlow GraphDef | `pb` | yolov8n.pb
 TensorFlow Lite | `tflite` | yolov8n.tflite
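The docstring table above is the user-facing contract for the `format=` argument: with this change, `format='coreml'` produces a `yolov8n.mlpackage` bundle instead of the legacy `yolov8n.mlmodel` file. A minimal sketch of exercising it, assuming the standard `ultralytics.YOLO` entry point:

```python
# Minimal export sketch, assuming the standard ultralytics Python API.
from ultralytics import YOLO

model = YOLO('yolov8n.pt')     # load a pretrained detection model
model.export(format='coreml')  # now writes yolov8n.mlpackage (was yolov8n.mlmodel)
```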
@@ -35,7 +35,7 @@ Inference:
     yolov8n.onnx               # ONNX Runtime or OpenCV DNN with dnn=True
     yolov8n_openvino_model     # OpenVINO
     yolov8n.engine             # TensorRT
-    yolov8n.mlmodel            # CoreML (macOS-only)
+    yolov8n.mlpackage          # CoreML (macOS-only)
     yolov8n_saved_model        # TensorFlow SavedModel
     yolov8n.pb                 # TensorFlow GraphDef
     yolov8n.tflite             # TensorFlow Lite
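The inference listing mirrors the same rename on the prediction side. A hedged sketch of running the exported bundle, assuming exported models load through the same `YOLO` constructor (CoreML inference is macOS-only):

```python
# Inference sketch with the exported CoreML bundle (macOS-only); the sample
# image URL is an assumption, any local path works the same way.
from ultralytics import YOLO

coreml_model = YOLO('yolov8n.mlpackage')                          # load exported model
results = coreml_model('https://ultralytics.com/images/bus.jpg')  # run prediction
```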
@@ -82,7 +82,7 @@ def export_formats():
         ['ONNX', 'onnx', '.onnx', True, True],
         ['OpenVINO', 'openvino', '_openvino_model', True, False],
         ['TensorRT', 'engine', '.engine', False, True],
-        ['CoreML', 'coreml', '.mlmodel', True, False],
+        ['CoreML', 'coreml', '.mlpackage', True, False],
         ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
         ['TensorFlow GraphDef', 'pb', '.pb', True, True],
         ['TensorFlow Lite', 'tflite', '.tflite', True, False],
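These five-element rows feed the format lookup; the `export_formats()['Argument']` indexing later in this diff suggests the helper wraps them in a pandas DataFrame with one column per field. A condensed sketch under that assumption (two rows shown, column names inferred):

```python
# Hypothetical condensed version of export_formats(); column names are
# inferred from the ['Argument'] indexing used later in this diff.
import pandas

def export_formats():
    x = [
        ['PyTorch', '-', '.pt', True, True],
        ['CoreML', 'coreml', '.mlpackage', True, False],  # suffix changed in this PR
    ]
    return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
```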
@@ -149,8 +149,10 @@ class Exporter:
         self.run_callbacks('on_export_start')
         t = time.time()
         format = self.args.format.lower()  # to lowercase
-        if format in ('tensorrt', 'trt'):  # engine aliases
+        if format in ('tensorrt', 'trt'):  # 'engine' aliases
             format = 'engine'
+        if format in ('mlmodel', 'mlpackage', 'mlprogram', 'apple', 'ios'):  # 'coreml' aliases
+            format = 'coreml'
         fmts = tuple(export_formats()['Argument'][1:])  # available export formats
         flags = [x == format for x in fmts]
         if sum(flags) != 1:
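The new branch means every CoreML-flavoured spelling collapses to the canonical `coreml` argument before the format flags are computed. Extracted into a hypothetical standalone helper for illustration:

```python
# Hypothetical standalone extraction of the alias logic added above.
def normalize_format(format: str) -> str:
    format = format.lower()  # to lowercase
    if format in ('tensorrt', 'trt'):  # 'engine' aliases
        format = 'engine'
    if format in ('mlmodel', 'mlpackage', 'mlprogram', 'apple', 'ios'):  # 'coreml' aliases
        format = 'coreml'
    return format

assert normalize_format('mlprogram') == 'coreml'
assert normalize_format('TRT') == 'engine'
```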
@@ -319,7 +321,7 @@ class Exporter:
             dynamic['output0'] = {0: 'batch', 2: 'anchors'}  # shape(1, 84, 8400)

         torch.onnx.export(
-            self.model.cpu() if dynamic else self.model,  # --dynamic only compatible with cpu
+            self.model.cpu() if dynamic else self.model,  # dynamic=True only compatible with cpu
             self.im.cpu() if dynamic else self.im,
             f,
             verbose=False,
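For reference, the `dynamic` dict built above is presumably handed to `torch.onnx.export` as `dynamic_axes`, which is why the model and example input must sit on CPU when dynamic shapes are requested. A self-contained sketch with a toy module standing in for YOLOv8 (file name and input names are assumptions matching the axis dict above):

```python
# Toy dynamic-axes export; Conv2d stands in for the real detection model.
import torch

model = torch.nn.Conv2d(3, 84, 1)  # toy stand-in for YOLOv8
im = torch.zeros(1, 3, 640, 640)   # example input
torch.onnx.export(
    model.cpu(),                   # dynamic=True only compatible with cpu
    im.cpu(),
    'toy.onnx',
    verbose=False,
    input_names=['images'],
    output_names=['output0'],
    dynamic_axes={'images': {0: 'batch'}, 'output0': {0: 'batch', 2: 'anchors'}})
```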
@@ -461,14 +463,16 @@ class Exporter:
         yaml_save(f / 'metadata.yaml', self.metadata)  # add metadata.yaml
         return str(f), None

     @try_export
     def export_coreml(self, prefix=colorstr('CoreML:')):
         """YOLOv8 CoreML export."""
-        check_requirements('coremltools>=6.0,<=6.2')
+        mlmodel = self.args.format.lower() == 'mlmodel'  # legacy *.mlmodel export format requested
+        check_requirements('coremltools>=6.0,<=6.2' if mlmodel else 'coremltools>=7.0.b1')
         import coremltools as ct  # noqa

         LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
-        f = self.file.with_suffix('.mlmodel')
+        f = self.file.with_suffix('.mlmodel' if mlmodel else '.mlpackage')
+        if f.is_dir():
+            shutil.rmtree(f)

         bias = [0.0, 0.0, 0.0]
         scale = 1 / 255
@@ -479,20 +483,38 @@ class Exporter:
         elif self.model.task == 'detect':
             model = iOSDetectModel(self.model, self.im) if self.args.nms else self.model
         else:
-            # TODO CoreML Segment and Pose model pipelining
+            if self.args.nms:
+                LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolov8n.pt'.")
+            # TODO CoreML Segment and Pose model pipelining
             model = self.model

         ts = torch.jit.trace(model.eval(), self.im, strict=False)  # TorchScript model
         ct_model = ct.convert(ts,
                               inputs=[ct.ImageType('image', shape=self.im.shape, scale=scale, bias=bias)],
-                              classifier_config=classifier_config)
-        bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
+                              classifier_config=classifier_config,
+                              convert_to='neuralnetwork' if mlmodel else 'mlprogram')
+        bits, mode = (8, 'kmeans') if self.args.int8 else (16, 'linear') if self.args.half else (32, None)
         if bits < 32:
             if 'kmeans' in mode:
                 check_requirements('scikit-learn')  # scikit-learn package required for k-means quantization
-            ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
+            if mlmodel:
+                ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
+            else:
+                import coremltools.optimize.coreml as cto
+                op_config = cto.OpPalettizerConfig(mode=mode, nbits=bits, weight_threshold=512)
+                config = cto.OptimizationConfig(global_config=op_config)
+                ct_model = cto.palettize_weights(ct_model, config=config)
         if self.args.nms and self.model.task == 'detect':
-            ct_model = self._pipeline_coreml(ct_model)
+            if mlmodel:
+                import platform
+
+                # coremltools<=6.2 NMS export requires Python<3.11
+                check_version(platform.python_version(), '<3.11', name='Python ', hard=True)
+                weights_dir = None
+            else:
+                ct_model.save(str(f))  # save otherwise weights_dir does not exist
+                weights_dir = str(f / 'Data/com.apple.CoreML/weights')
+            ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)

         m = self.metadata  # metadata dict
         ct_model.short_description = m.pop('description')
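The quantization fork above is the heart of the change: legacy `neuralnetwork` models keep the old `quantization_utils.quantize_weights` path, while `mlprogram` models use the coremltools 7 `optimize.coreml` palettization API. A standalone sketch of the new path with a toy traced model (shapes and file names are illustrative, and it assumes coremltools>=7.0 is installed):

```python
# New mlprogram path in isolation, with a toy model standing in for YOLOv8.
import torch
import coremltools as ct
import coremltools.optimize.coreml as cto

toy = torch.nn.Conv2d(3, 16, 3).eval()  # toy stand-in model
ts = torch.jit.trace(toy, torch.zeros(1, 3, 640, 640), strict=False)
ct_model = ct.convert(ts,
                      inputs=[ct.ImageType('image', shape=(1, 3, 640, 640), scale=1 / 255)],
                      convert_to='mlprogram')  # ML Program -> *.mlpackage

# int8: k-means weight palettization replaces the legacy quantize_weights call
op_config = cto.OpPalettizerConfig(mode='kmeans', nbits=8, weight_threshold=512)
config = cto.OptimizationConfig(global_config=op_config)
ct_model = cto.palettize_weights(ct_model, config=config)
ct_model.save('toy.mlpackage')
```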
@@ -500,7 +522,14 @@ class Exporter:
         ct_model.license = m.pop('license')
         ct_model.version = m.pop('version')
         ct_model.user_defined_metadata.update({k: str(v) for k, v in m.items()})
-        ct_model.save(str(f))
+        try:
+            ct_model.save(str(f))  # save *.mlpackage
+        except Exception as e:
+            LOGGER.warning(
+                f'{prefix} WARNING ⚠️ CoreML export to *.mlpackage failed ({e}), reverting to *.mlmodel export. '
+                f'Known coremltools Python 3.11 and Windows bugs https://github.com/apple/coremltools/issues/1928.')
+            f = f.with_suffix('.mlmodel')
+            ct_model.save(str(f))
         return f, ct_model

     @try_export
@@ -546,7 +575,7 @@ class Exporter:
         if self.args.dynamic:
             shape = self.im.shape
             if shape[0] <= 1:
-                LOGGER.warning(f'{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument')
+                LOGGER.warning(f"{prefix} WARNING ⚠️ 'dynamic=True' model requires max batch size, i.e. 'batch=16'")
             profile = builder.create_optimization_profile()
             for inp in inputs:
                 profile.set_shape(inp.name, (1, *shape[1:]), (max(1, shape[0] // 2), *shape[1:]), shape)
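Besides the reworded warning (the CLI moved from `--dynamic`/`--batch-size` flags to `dynamic=True`/`batch=16` style arguments), the surrounding logic is unchanged: the batch size used at export becomes the TensorRT profile's maximum shape, with the optimal shape at half. Concretely, for a 640x640 export with `batch=16`, the values handed to `profile.set_shape` would be:

```python
# Illustrative shapes for dynamic=True, batch=16, imgsz=640.
shape = (16, 3, 640, 640)                        # self.im.shape (max batch)
min_shape = (1, *shape[1:])                      # (1, 3, 640, 640)
opt_shape = (max(1, shape[0] // 2), *shape[1:])  # (8, 3, 640, 640)
max_shape = shape                                # (16, 3, 640, 640)
# for inp in inputs: profile.set_shape(inp.name, min_shape, opt_shape, max_shape)
```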
@@ -805,7 +834,7 @@ class Exporter:
         populator.populate()
         tmp_file.unlink()

-    def _pipeline_coreml(self, model, prefix=colorstr('CoreML Pipeline:')):
+    def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr('CoreML Pipeline:')):
         """YOLOv8 CoreML pipeline."""
         import coremltools as ct  # noqa

@@ -853,7 +882,7 @@ class Exporter:
         # print(spec.description)

         # Model from spec
-        model = ct.models.MLModel(spec)
+        model = ct.models.MLModel(spec, weights_dir=weights_dir)

         # 3. Create NMS protobuf
         nms_spec = ct.proto.Model_pb2.Model()
@@ -912,7 +941,7 @@ class Exporter:
             'Confidence threshold': str(nms.confidenceThreshold)})

         # Save the model
-        model = ct.models.MLModel(pipeline.spec)
+        model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
         model.input_description['image'] = 'Input image'
         model.input_description['iouThreshold'] = f'(optional) IOU threshold override (default: {nms.iouThreshold})'
         model.input_description['confidenceThreshold'] = \
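The `weights_dir` threading exists because ML Program weights are stored as files beside the spec inside the `*.mlpackage` bundle rather than embedded in the protobuf; rebuilding an `MLModel` from a mutated spec therefore has to point back at them, which is why the export hunk calls `ct_model.save(...)` before pipelining. A hedged sketch, reusing `ct_model` from the palettization sketch above:

```python
# Rebuilding an MLModel from a bare spec for an mlprogram model; the
# weights path follows the layout used earlier in this diff.
import coremltools as ct

ct_model.save('toy.mlpackage')  # materialize bundle so the weights exist on disk
weights_dir = 'toy.mlpackage/Data/com.apple.CoreML/weights'
spec = ct_model.get_spec()      # mutate spec here (rename outputs, add NMS, ...)
model = ct.models.MLModel(spec, weights_dir=weights_dir)
```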