Support CoreML NMS export for Segment, Pose and OBB (#19173)

Signed-off-by: Mohammed Yasin <32206511+Y-T-G@users.noreply.github.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
pull/19273/head
Mohammed Yasin 4 weeks ago committed by GitHub
parent d92ab8764b
commit 0ae4670da6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 3
      docs/en/macros/export-table.md
  2. 4
      docs/en/reference/engine/exporter.md
  3. 13
      tests/test_exports.py
  4. 152
      ultralytics/engine/exporter.py
  5. 10
      ultralytics/nn/autobackend.py
  6. 5
      ultralytics/utils/ops.py

@ -1,5 +1,4 @@
{%set tip1 = ':material-information-outline:{ title="conf, iou, agnostic_nms are also available when nms=True" }' %}
{%set tip2 = ':material-information-outline:{ title="conf, iou are also available when nms=True" }' %}
| Format | `format` Argument | Model | Metadata | Arguments |
| ------------------------------------------------- | ----------------- | ----------------------------------------------- | -------- | --------------------------------------------------------------------------------------------- |
@ -8,7 +7,7 @@
| [ONNX](../integrations/onnx.md) | `onnx` | `{{ model_name or "yolo11n" }}.onnx` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `opset`, `nms`{{ tip1 }}, `batch` |
| [OpenVINO](../integrations/openvino.md) | `openvino` | `{{ model_name or "yolo11n" }}_openvino_model/` | ✅ | `imgsz`, `half`, `dynamic`, `int8`, `nms`{{ tip1 }}, `batch`, `data` |
| [TensorRT](../integrations/tensorrt.md) | `engine` | `{{ model_name or "yolo11n" }}.engine` | ✅ | `imgsz`, `half`, `dynamic`, `simplify`, `workspace`, `int8`, `nms`{{ tip1 }}, `batch`, `data` |
| [CoreML](../integrations/coreml.md) | `coreml` | `{{ model_name or "yolo11n" }}.mlpackage` | ✅ | `imgsz`, `half`, `int8`, `nms`{{ tip2 }}, `batch` |
| [CoreML](../integrations/coreml.md) | `coreml` | `{{ model_name or "yolo11n" }}.mlpackage` | ✅ | `imgsz`, `half`, `int8`, `nms`{{ tip1 }}, `batch` |
| [TF SavedModel](../integrations/tf-savedmodel.md) | `saved_model` | `{{ model_name or "yolo11n" }}_saved_model/` | ✅ | `imgsz`, `keras`, `int8`, `nms`{{ tip1 }}, `batch` |
| [TF GraphDef](../integrations/tf-graphdef.md) | `pb` | `{{ model_name or "yolo11n" }}.pb` | ❌ | `imgsz`, `batch` |
| [TF Lite](../integrations/tflite.md) | `tflite` | `{{ model_name or "yolo11n" }}.tflite` | ✅ | `imgsz`, `half`, `int8`, `nms`{{ tip1 }}, `batch`, `data` |

@ -15,10 +15,6 @@ keywords: YOLOv8, export formats, ONNX, TensorRT, CoreML, machine learning model
<br><br><hr><br>
## ::: ultralytics.engine.exporter.IOSDetectModel
<br><br><hr><br>
## ::: ultralytics.engine.exporter.NMSModel
<br><br><hr><br>

@ -116,14 +116,16 @@ def test_export_torchscript_matrix(task, dynamic, int8, half, batch, nms):
@pytest.mark.skipif(not TORCH_1_9, reason="CoreML>=7.2 not supported with PyTorch<=1.8")
@pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="CoreML not supported in Python 3.12")
@pytest.mark.parametrize(
"task, dynamic, int8, half, batch",
"task, dynamic, int8, half, batch, nms",
[ # generate all combinations except for exclusion cases
(task, dynamic, int8, half, batch)
for task, dynamic, int8, half, batch in product(TASKS, [False], [True, False], [True, False], [1])
if not (int8 and half)
(task, dynamic, int8, half, batch, nms)
for task, dynamic, int8, half, batch, nms in product(
TASKS, [False], [True, False], [True, False], [1], [True, False]
)
if not ((int8 and half) or (task == "classify" and nms))
],
)
def test_export_coreml_matrix(task, dynamic, int8, half, batch):
def test_export_coreml_matrix(task, dynamic, int8, half, batch, nms):
"""Test YOLO exports to CoreML format with various parameter configurations."""
file = YOLO(TASK2MODEL[task]).export(
format="coreml",
@ -132,6 +134,7 @@ def test_export_coreml_matrix(task, dynamic, int8, half, batch):
int8=int8,
half=half,
batch=batch,
nms=nms,
)
YOLO(file)([SOURCE] * batch, imgsz=32) # exported model inference at batch=3
shutil.rmtree(file) # cleanup

@ -84,7 +84,6 @@ from ultralytics.utils import (
LINUX,
LOGGER,
MACOS,
PYTHON_VERSION,
RKNN_CHIPS,
ROOT,
WINDOWS,
@ -356,7 +355,7 @@ class Exporter:
y = None
for _ in range(2): # dry runs
y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im)
y = NMSModel(model, self.args)(im) if self.args.nms else model(im)
if self.args.half and onnx and self.device.type != "cpu":
im, model = im.half(), model.half() # to FP16
@ -766,12 +765,9 @@ class Exporter:
if self.model.task == "classify":
classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None
model = self.model
elif self.model.task == "detect":
model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model
elif self.args.nms:
model = NMSModel(self.model, self.args)
else:
if self.args.nms:
LOGGER.warning(f"{prefix} WARNING ⚠ 'nms=True' is only available for Detect models like 'yolo11n.pt'.")
# TODO CoreML Segment and Pose model pipelining
model = self.model
ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model
@ -793,15 +789,6 @@ class Exporter:
op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512)
config = cto.OptimizationConfig(global_config=op_config)
ct_model = cto.palettize_weights(ct_model, config=config)
if self.args.nms and self.model.task == "detect":
if mlmodel:
# coremltools<=6.2 NMS export requires Python<3.11
check_version(PYTHON_VERSION, "<3.11", name="Python ", hard=True)
weights_dir = None
else:
ct_model.save(str(f)) # save otherwise weights_dir does not exist
weights_dir = str(f / "Data/com.apple.CoreML/weights")
ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir)
m = self.metadata # metadata dict
ct_model.short_description = m.pop("description")
@ -1391,112 +1378,6 @@ class Exporter:
populator.populate()
tmp_file.unlink()
def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")):
"""YOLO CoreML pipeline."""
import coremltools as ct # noqa
LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...")
_, _, h, w = list(self.im.shape) # BCHW
# Output shapes
spec = model.get_spec()
out0, out1 = iter(spec.description.output)
if MACOS:
from PIL import Image
img = Image.new("RGB", (w, h)) # w=192, h=320
out = model.predict({"image": img})
out0_shape = out[out0.name].shape # (3780, 80)
out1_shape = out[out1.name].shape # (3780, 4)
else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y
out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80)
out1_shape = self.output_shape[2], 4 # (3780, 4)
# Checks
names = self.metadata["names"]
nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height
_, nc = out0_shape # number of anchors, number of classes
assert len(names) == nc, f"{len(names)} names found for nc={nc}" # check
# Define output shapes (missing)
out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80)
out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4)
# Model from spec
model = ct.models.MLModel(spec, weights_dir=weights_dir)
# 3. Create NMS protobuf
nms_spec = ct.proto.Model_pb2.Model()
nms_spec.specificationVersion = 5
for i in range(2):
decoder_output = model._spec.description.output[i].SerializeToString()
nms_spec.description.input.add()
nms_spec.description.input[i].ParseFromString(decoder_output)
nms_spec.description.output.add()
nms_spec.description.output[i].ParseFromString(decoder_output)
nms_spec.description.output[0].name = "confidence"
nms_spec.description.output[1].name = "coordinates"
output_sizes = [nc, 4]
for i in range(2):
ma_type = nms_spec.description.output[i].type.multiArrayType
ma_type.shapeRange.sizeRanges.add()
ma_type.shapeRange.sizeRanges[0].lowerBound = 0
ma_type.shapeRange.sizeRanges[0].upperBound = -1
ma_type.shapeRange.sizeRanges.add()
ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i]
ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i]
del ma_type.shape[:]
nms = nms_spec.nonMaximumSuppression
nms.confidenceInputFeatureName = out0.name # 1x507x80
nms.coordinatesInputFeatureName = out1.name # 1x507x4
nms.confidenceOutputFeatureName = "confidence"
nms.coordinatesOutputFeatureName = "coordinates"
nms.iouThresholdInputFeatureName = "iouThreshold"
nms.confidenceThresholdInputFeatureName = "confidenceThreshold"
nms.iouThreshold = self.args.iou
nms.confidenceThreshold = self.args.conf
nms.pickTop.perClass = True
nms.stringClassLabels.vector.extend(names.values())
nms_model = ct.models.MLModel(nms_spec)
# 4. Pipeline models together
pipeline = ct.models.pipeline.Pipeline(
input_features=[
("image", ct.models.datatypes.Array(3, ny, nx)),
("iouThreshold", ct.models.datatypes.Double()),
("confidenceThreshold", ct.models.datatypes.Double()),
],
output_features=["confidence", "coordinates"],
)
pipeline.add_model(model)
pipeline.add_model(nms_model)
# Correct datatypes
pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString())
pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString())
pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString())
# Update metadata
pipeline.spec.specificationVersion = 5
pipeline.spec.description.metadata.userDefined.update(
{"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)}
)
# Save the model
model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
model.input_description["image"] = "Input image"
model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})"
model.input_description["confidenceThreshold"] = (
f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
)
model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
LOGGER.info(f"{prefix} pipeline success")
return model
def add_callback(self, event: str, callback):
"""Appends the given callback."""
self.callbacks[event].append(callback)
@ -1507,26 +1388,6 @@ class Exporter:
callback(self)
class IOSDetectModel(torch.nn.Module):
"""Wrap an Ultralytics YOLO model for Apple iOS CoreML export."""
def __init__(self, model, im):
"""Initialize the IOSDetectModel class with a YOLO model and example image."""
super().__init__()
_, _, h, w = im.shape # batch, channel, height, width
self.model = model
self.nc = len(model.names) # number of classes
if w == h:
self.normalize = 1.0 / w # scalar
else:
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller)
def forward(self, x):
"""Normalize predictions of object detection model with input size-dependent factors."""
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1)
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4)
class NMSModel(torch.nn.Module):
"""Model wrapper with embedded NMS for Detect, Segment, Pose and OBB."""
@ -1585,7 +1446,8 @@ class NMSModel(torch.nn.Module):
box = xywh2xyxy(box)
if self.is_tf:
# TFlite bug returns less boxes
box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0]))
pad = torch.zeros((mask.shape[0] - box.shape[0], box.shape[-1]), device=box.device, dtype=box.dtype)
box = torch.cat((box, pad))
nmsbox = box.clone()
# `8` is the minimum value experimented to get correct NMS results for obb
multiplier = 8 if self.obb else 1
@ -1622,6 +1484,6 @@ class NMSModel(torch.nn.Module):
[box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1).to(out.dtype), extra[keep]], dim=-1
)
# Zero-pad to max_det size to avoid reshape error
pad = (0, 0, 0, self.args.max_det - dets.shape[0])
out[i] = torch.nn.functional.pad(dets, pad)
pad = torch.zeros((self.args.max_det - dets.shape[0], out.shape[-1]), device=out.device, dtype=out.dtype)
out[i] = torch.cat((dets, pad))
return (out, preds[1]) if self.model.task == "segment" else out

@ -640,14 +640,10 @@ class AutoBackend(nn.Module):
y = self.model.predict({"image": im_pil}) # coordinates are xywh normalized
if "confidence" in y:
raise TypeError(
"Ultralytics only supports inference of non-pipelined CoreML models exported with "
f"'nms=False', but 'model={w}' has an NMS pipeline created by an 'nms=True' export."
"'model={w}' has an NMS pipeline created by an older version of Ultralytics. "
"CoreML inference with NMS is only supported for models exported with latest Ultralytics. "
"You may export the model again with latest Ultralytics to resolve this."
)
# TODO: CoreML NMS inference handling
# from ultralytics.utils.ops import xywh2xyxy
# box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
# conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float32)
# y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
y = list(y.values())
if len(y) == 2 and len(y[1].shape) != 4: # segmentation model
y = list(reversed(y)) # reversed for segmentation models (pred, proto)

@ -441,12 +441,9 @@ def xywh2xyxy(x):
y (np.ndarray | torch.Tensor): The bounding box coordinates in (x1, y1, x2, y2) format.
"""
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = empty_like(x) # faster than clone/copy
xy = x[..., :2] # centers
wh = x[..., 2:] / 2 # half width-height
y[..., :2] = xy - wh # top left xy
y[..., 2:] = xy + wh # bottom right xy
return y
return (np.concatenate if isinstance(x, np.ndarray) else torch.cat)((xy - wh, xy + wh), -1)
def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):

Loading…
Cancel
Save