|
|
|
@ -65,6 +65,7 @@ import pandas as pd |
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
import ultralytics |
|
|
|
|
from ultralytics.nn.autobackend import check_class_names |
|
|
|
|
from ultralytics.nn.modules import Detect, Segment |
|
|
|
|
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task |
|
|
|
|
from ultralytics.yolo.cfg import get_cfg |
|
|
|
@ -151,9 +152,12 @@ class Exporter: |
|
|
|
|
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic' |
|
|
|
|
|
|
|
|
|
# Checks |
|
|
|
|
model.names = check_class_names(model.names) |
|
|
|
|
# if self.args.batch == model.args['batch_size']: # user has not modified training batch_size |
|
|
|
|
self.args.batch = 1 |
|
|
|
|
self.imgsz = check_imgsz(self.args.imgsz, stride=model.stride, min_dim=2) # check image size |
|
|
|
|
if model.task == 'classify': |
|
|
|
|
self.args.nms = self.args.agnostic_nms = False |
|
|
|
|
if self.args.optimize: |
|
|
|
|
assert self.device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu' |
|
|
|
|
|
|
|
|
@ -194,8 +198,14 @@ class Exporter: |
|
|
|
|
self.model = model |
|
|
|
|
self.file = file |
|
|
|
|
self.output_shape = tuple(y.shape) if isinstance(y, torch.Tensor) else (x.shape for x in y) |
|
|
|
|
self.metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata |
|
|
|
|
self.pretty_name = self.file.stem.replace('yolo', 'YOLO') |
|
|
|
|
self.metadata = { |
|
|
|
|
'description': f"Ultralytics {self.pretty_name} model trained on {self.model.args['data']}", |
|
|
|
|
'author': 'Ultralytics', |
|
|
|
|
'license': 'GPL-3.0 https://ultralytics.com/license', |
|
|
|
|
'version': ultralytics.__version__, |
|
|
|
|
'stride': int(max(model.stride)), |
|
|
|
|
'names': model.names} # model metadata |
|
|
|
|
|
|
|
|
|
# Exports |
|
|
|
|
f = [''] * len(fmts) # exported filenames |
|
|
|
@ -235,12 +245,11 @@ class Exporter: |
|
|
|
|
# Finish |
|
|
|
|
f = [str(x) for x in f if x] # filter out '' and None |
|
|
|
|
if any(f): |
|
|
|
|
task = guess_model_task(model) |
|
|
|
|
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models" |
|
|
|
|
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' |
|
|
|
|
f"\nResults saved to {colorstr('bold', file.parent.resolve())}" |
|
|
|
|
f"\nPredict: yolo task={task} mode=predict model={f[-1]} {s}" |
|
|
|
|
f"\nValidate: yolo task={task} mode=val model={f[-1]} {s}" |
|
|
|
|
f"\nPredict: yolo task={model.task} mode=predict model={f[-1]} {s}" |
|
|
|
|
f"\nValidate: yolo task={model.task} mode=val model={f[-1]} {s}" |
|
|
|
|
f"\nVisualize: https://netron.app") |
|
|
|
|
|
|
|
|
|
self.run_callbacks("on_export_end") |
|
|
|
@ -375,9 +384,13 @@ class Exporter: |
|
|
|
|
LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') |
|
|
|
|
f = self.file.with_suffix('.mlmodel') |
|
|
|
|
|
|
|
|
|
task = self.model.task |
|
|
|
|
model = iOSModel(self.model, self.im).eval() if self.args.nms else self.model |
|
|
|
|
ts = torch.jit.trace(model, self.im, strict=False) # TorchScript model |
|
|
|
|
ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])]) |
|
|
|
|
classifier_config = ct.ClassifierConfig(list(model.names.values())) if task == 'classify' else None |
|
|
|
|
ct_model = ct.convert(ts, |
|
|
|
|
inputs=[ct.ImageType('image', shape=self.im.shape, scale=1 / 255, bias=[0, 0, 0])], |
|
|
|
|
classifier_config=classifier_config) |
|
|
|
|
bits, mode = (8, 'kmeans_lut') if self.args.int8 else (16, 'linear') if self.args.half else (32, None) |
|
|
|
|
if bits < 32: |
|
|
|
|
if MACOS: # quantization only supported on macOS |
|
|
|
@ -387,6 +400,10 @@ class Exporter: |
|
|
|
|
if self.args.nms: |
|
|
|
|
ct_model = self._pipeline_coreml(ct_model) |
|
|
|
|
|
|
|
|
|
ct_model.short_description = self.metadata['description'] |
|
|
|
|
ct_model.author = self.metadata['author'] |
|
|
|
|
ct_model.license = self.metadata['license'] |
|
|
|
|
ct_model.version = self.metadata['version'] |
|
|
|
|
ct_model.save(str(f)) |
|
|
|
|
return f, ct_model |
|
|
|
|
|
|
|
|
@ -687,8 +704,8 @@ class Exporter: |
|
|
|
|
out0_shape = out[out0.name].shape |
|
|
|
|
out1_shape = out[out1.name].shape |
|
|
|
|
else: # linux and windows can not run model.predict(), get sizes from pytorch output y |
|
|
|
|
out0_shape = self.output_shape[1], self.output_shape[2] - 5 # (3780, 80) |
|
|
|
|
out1_shape = self.output_shape[1], 4 # (3780, 4) |
|
|
|
|
out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80) |
|
|
|
|
out1_shape = self.output_shape[2], 4 # (3780, 4) |
|
|
|
|
|
|
|
|
|
# Checks |
|
|
|
|
names = self.metadata['names'] |
|
|
|
@ -714,7 +731,7 @@ class Exporter: |
|
|
|
|
# flexible_shape_utils.update_image_size_range(spec, feature_name='image', size_range=r) |
|
|
|
|
|
|
|
|
|
# Print |
|
|
|
|
print(spec.description) |
|
|
|
|
# print(spec.description) |
|
|
|
|
|
|
|
|
|
# Model from spec |
|
|
|
|
model = ct.models.MLModel(spec) |
|
|
|
@ -771,10 +788,6 @@ class Exporter: |
|
|
|
|
|
|
|
|
|
# Update metadata |
|
|
|
|
pipeline.spec.specificationVersion = 5 |
|
|
|
|
pipeline.spec.description.metadata.versionString = f'Ultralytics YOLOv{ultralytics.__version__}' |
|
|
|
|
pipeline.spec.description.metadata.shortDescription = f'Ultralytics {self.pretty_name} CoreML model' |
|
|
|
|
pipeline.spec.description.metadata.author = 'Ultralytics (https://ultralytics.com)' |
|
|
|
|
pipeline.spec.description.metadata.license = 'GPL-3.0 license (https://ultralytics.com/license)' |
|
|
|
|
pipeline.spec.description.metadata.userDefined.update({ |
|
|
|
|
'IoU threshold': str(nms.iouThreshold), |
|
|
|
|
'Confidence threshold': str(nms.confidenceThreshold)}) |
|
|
|
|