|
|
|
@ -84,7 +84,6 @@ from ultralytics.utils import ( |
|
|
|
|
LINUX, |
|
|
|
|
LOGGER, |
|
|
|
|
MACOS, |
|
|
|
|
PYTHON_VERSION, |
|
|
|
|
RKNN_CHIPS, |
|
|
|
|
ROOT, |
|
|
|
|
WINDOWS, |
|
|
|
@ -356,7 +355,7 @@ class Exporter: |
|
|
|
|
|
|
|
|
|
y = None |
|
|
|
|
for _ in range(2): # dry runs |
|
|
|
|
y = NMSModel(model, self.args)(im) if self.args.nms and not coreml else model(im) |
|
|
|
|
y = NMSModel(model, self.args)(im) if self.args.nms else model(im) |
|
|
|
|
if self.args.half and onnx and self.device.type != "cpu": |
|
|
|
|
im, model = im.half(), model.half() # to FP16 |
|
|
|
|
|
|
|
|
@ -766,12 +765,9 @@ class Exporter: |
|
|
|
|
if self.model.task == "classify": |
|
|
|
|
classifier_config = ct.ClassifierConfig(list(self.model.names.values())) if self.args.nms else None |
|
|
|
|
model = self.model |
|
|
|
|
elif self.model.task == "detect": |
|
|
|
|
model = IOSDetectModel(self.model, self.im) if self.args.nms else self.model |
|
|
|
|
elif self.args.nms: |
|
|
|
|
model = NMSModel(self.model, self.args) |
|
|
|
|
else: |
|
|
|
|
if self.args.nms: |
|
|
|
|
LOGGER.warning(f"{prefix} WARNING ⚠️ 'nms=True' is only available for Detect models like 'yolo11n.pt'.") |
|
|
|
|
# TODO CoreML Segment and Pose model pipelining |
|
|
|
|
model = self.model |
|
|
|
|
|
|
|
|
|
ts = torch.jit.trace(model.eval(), self.im, strict=False) # TorchScript model |
|
|
|
@ -793,15 +789,6 @@ class Exporter: |
|
|
|
|
op_config = cto.OpPalettizerConfig(mode="kmeans", nbits=bits, weight_threshold=512) |
|
|
|
|
config = cto.OptimizationConfig(global_config=op_config) |
|
|
|
|
ct_model = cto.palettize_weights(ct_model, config=config) |
|
|
|
|
if self.args.nms and self.model.task == "detect": |
|
|
|
|
if mlmodel: |
|
|
|
|
# coremltools<=6.2 NMS export requires Python<3.11 |
|
|
|
|
check_version(PYTHON_VERSION, "<3.11", name="Python ", hard=True) |
|
|
|
|
weights_dir = None |
|
|
|
|
else: |
|
|
|
|
ct_model.save(str(f)) # save otherwise weights_dir does not exist |
|
|
|
|
weights_dir = str(f / "Data/com.apple.CoreML/weights") |
|
|
|
|
ct_model = self._pipeline_coreml(ct_model, weights_dir=weights_dir) |
|
|
|
|
|
|
|
|
|
m = self.metadata # metadata dict |
|
|
|
|
ct_model.short_description = m.pop("description") |
|
|
|
@ -1391,112 +1378,6 @@ class Exporter: |
|
|
|
|
populator.populate() |
|
|
|
|
tmp_file.unlink() |
|
|
|
|
|
|
|
|
|
def _pipeline_coreml(self, model, weights_dir=None, prefix=colorstr("CoreML Pipeline:")): |
|
|
|
|
"""YOLO CoreML pipeline.""" |
|
|
|
|
import coremltools as ct # noqa |
|
|
|
|
|
|
|
|
|
LOGGER.info(f"{prefix} starting pipeline with coremltools {ct.__version__}...") |
|
|
|
|
_, _, h, w = list(self.im.shape) # BCHW |
|
|
|
|
|
|
|
|
|
# Output shapes |
|
|
|
|
spec = model.get_spec() |
|
|
|
|
out0, out1 = iter(spec.description.output) |
|
|
|
|
if MACOS: |
|
|
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
img = Image.new("RGB", (w, h)) # w=192, h=320 |
|
|
|
|
out = model.predict({"image": img}) |
|
|
|
|
out0_shape = out[out0.name].shape # (3780, 80) |
|
|
|
|
out1_shape = out[out1.name].shape # (3780, 4) |
|
|
|
|
else: # linux and windows can not run model.predict(), get sizes from PyTorch model output y |
|
|
|
|
out0_shape = self.output_shape[2], self.output_shape[1] - 4 # (3780, 80) |
|
|
|
|
out1_shape = self.output_shape[2], 4 # (3780, 4) |
|
|
|
|
|
|
|
|
|
# Checks |
|
|
|
|
names = self.metadata["names"] |
|
|
|
|
nx, ny = spec.description.input[0].type.imageType.width, spec.description.input[0].type.imageType.height |
|
|
|
|
_, nc = out0_shape # number of anchors, number of classes |
|
|
|
|
assert len(names) == nc, f"{len(names)} names found for nc={nc}" # check |
|
|
|
|
|
|
|
|
|
# Define output shapes (missing) |
|
|
|
|
out0.type.multiArrayType.shape[:] = out0_shape # (3780, 80) |
|
|
|
|
out1.type.multiArrayType.shape[:] = out1_shape # (3780, 4) |
|
|
|
|
|
|
|
|
|
# Model from spec |
|
|
|
|
model = ct.models.MLModel(spec, weights_dir=weights_dir) |
|
|
|
|
|
|
|
|
|
# 3. Create NMS protobuf |
|
|
|
|
nms_spec = ct.proto.Model_pb2.Model() |
|
|
|
|
nms_spec.specificationVersion = 5 |
|
|
|
|
for i in range(2): |
|
|
|
|
decoder_output = model._spec.description.output[i].SerializeToString() |
|
|
|
|
nms_spec.description.input.add() |
|
|
|
|
nms_spec.description.input[i].ParseFromString(decoder_output) |
|
|
|
|
nms_spec.description.output.add() |
|
|
|
|
nms_spec.description.output[i].ParseFromString(decoder_output) |
|
|
|
|
|
|
|
|
|
nms_spec.description.output[0].name = "confidence" |
|
|
|
|
nms_spec.description.output[1].name = "coordinates" |
|
|
|
|
|
|
|
|
|
output_sizes = [nc, 4] |
|
|
|
|
for i in range(2): |
|
|
|
|
ma_type = nms_spec.description.output[i].type.multiArrayType |
|
|
|
|
ma_type.shapeRange.sizeRanges.add() |
|
|
|
|
ma_type.shapeRange.sizeRanges[0].lowerBound = 0 |
|
|
|
|
ma_type.shapeRange.sizeRanges[0].upperBound = -1 |
|
|
|
|
ma_type.shapeRange.sizeRanges.add() |
|
|
|
|
ma_type.shapeRange.sizeRanges[1].lowerBound = output_sizes[i] |
|
|
|
|
ma_type.shapeRange.sizeRanges[1].upperBound = output_sizes[i] |
|
|
|
|
del ma_type.shape[:] |
|
|
|
|
|
|
|
|
|
nms = nms_spec.nonMaximumSuppression |
|
|
|
|
nms.confidenceInputFeatureName = out0.name # 1x507x80 |
|
|
|
|
nms.coordinatesInputFeatureName = out1.name # 1x507x4 |
|
|
|
|
nms.confidenceOutputFeatureName = "confidence" |
|
|
|
|
nms.coordinatesOutputFeatureName = "coordinates" |
|
|
|
|
nms.iouThresholdInputFeatureName = "iouThreshold" |
|
|
|
|
nms.confidenceThresholdInputFeatureName = "confidenceThreshold" |
|
|
|
|
nms.iouThreshold = self.args.iou |
|
|
|
|
nms.confidenceThreshold = self.args.conf |
|
|
|
|
nms.pickTop.perClass = True |
|
|
|
|
nms.stringClassLabels.vector.extend(names.values()) |
|
|
|
|
nms_model = ct.models.MLModel(nms_spec) |
|
|
|
|
|
|
|
|
|
# 4. Pipeline models together |
|
|
|
|
pipeline = ct.models.pipeline.Pipeline( |
|
|
|
|
input_features=[ |
|
|
|
|
("image", ct.models.datatypes.Array(3, ny, nx)), |
|
|
|
|
("iouThreshold", ct.models.datatypes.Double()), |
|
|
|
|
("confidenceThreshold", ct.models.datatypes.Double()), |
|
|
|
|
], |
|
|
|
|
output_features=["confidence", "coordinates"], |
|
|
|
|
) |
|
|
|
|
pipeline.add_model(model) |
|
|
|
|
pipeline.add_model(nms_model) |
|
|
|
|
|
|
|
|
|
# Correct datatypes |
|
|
|
|
pipeline.spec.description.input[0].ParseFromString(model._spec.description.input[0].SerializeToString()) |
|
|
|
|
pipeline.spec.description.output[0].ParseFromString(nms_model._spec.description.output[0].SerializeToString()) |
|
|
|
|
pipeline.spec.description.output[1].ParseFromString(nms_model._spec.description.output[1].SerializeToString()) |
|
|
|
|
|
|
|
|
|
# Update metadata |
|
|
|
|
pipeline.spec.specificationVersion = 5 |
|
|
|
|
pipeline.spec.description.metadata.userDefined.update( |
|
|
|
|
{"IoU threshold": str(nms.iouThreshold), "Confidence threshold": str(nms.confidenceThreshold)} |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Save the model |
|
|
|
|
model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir) |
|
|
|
|
model.input_description["image"] = "Input image" |
|
|
|
|
model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})" |
|
|
|
|
model.input_description["confidenceThreshold"] = ( |
|
|
|
|
f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})" |
|
|
|
|
) |
|
|
|
|
model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")' |
|
|
|
|
model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)" |
|
|
|
|
LOGGER.info(f"{prefix} pipeline success") |
|
|
|
|
return model |
|
|
|
|
|
|
|
|
|
def add_callback(self, event: str, callback): |
|
|
|
|
"""Appends the given callback.""" |
|
|
|
|
self.callbacks[event].append(callback) |
|
|
|
@ -1507,26 +1388,6 @@ class Exporter: |
|
|
|
|
callback(self) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IOSDetectModel(torch.nn.Module): |
|
|
|
|
"""Wrap an Ultralytics YOLO model for Apple iOS CoreML export.""" |
|
|
|
|
|
|
|
|
|
def __init__(self, model, im): |
|
|
|
|
"""Initialize the IOSDetectModel class with a YOLO model and example image.""" |
|
|
|
|
super().__init__() |
|
|
|
|
_, _, h, w = im.shape # batch, channel, height, width |
|
|
|
|
self.model = model |
|
|
|
|
self.nc = len(model.names) # number of classes |
|
|
|
|
if w == h: |
|
|
|
|
self.normalize = 1.0 / w # scalar |
|
|
|
|
else: |
|
|
|
|
self.normalize = torch.tensor([1.0 / w, 1.0 / h, 1.0 / w, 1.0 / h]) # broadcast (slower, smaller) |
|
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
|
"""Normalize predictions of object detection model with input size-dependent factors.""" |
|
|
|
|
xywh, cls = self.model(x)[0].transpose(0, 1).split((4, self.nc), 1) |
|
|
|
|
return cls, xywh * self.normalize # confidence (3780, 80), coordinates (3780, 4) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NMSModel(torch.nn.Module): |
|
|
|
|
"""Model wrapper with embedded NMS for Detect, Segment, Pose and OBB.""" |
|
|
|
|
|
|
|
|
@ -1585,7 +1446,8 @@ class NMSModel(torch.nn.Module): |
|
|
|
|
box = xywh2xyxy(box) |
|
|
|
|
if self.is_tf: |
|
|
|
|
# TFlite bug returns less boxes |
|
|
|
|
box = torch.nn.functional.pad(box, (0, 0, 0, mask.shape[0] - box.shape[0])) |
|
|
|
|
pad = torch.zeros((mask.shape[0] - box.shape[0], box.shape[-1]), device=box.device, dtype=box.dtype) |
|
|
|
|
box = torch.cat((box, pad)) |
|
|
|
|
nmsbox = box.clone() |
|
|
|
|
# `8` is the minimum value experimented to get correct NMS results for obb |
|
|
|
|
multiplier = 8 if self.obb else 1 |
|
|
|
@ -1622,6 +1484,6 @@ class NMSModel(torch.nn.Module): |
|
|
|
|
[box[keep], score[keep].view(-1, 1), cls[keep].view(-1, 1).to(out.dtype), extra[keep]], dim=-1 |
|
|
|
|
) |
|
|
|
|
# Zero-pad to max_det size to avoid reshape error |
|
|
|
|
pad = (0, 0, 0, self.args.max_det - dets.shape[0]) |
|
|
|
|
out[i] = torch.nn.functional.pad(dets, pad) |
|
|
|
|
pad = torch.zeros((self.args.max_det - dets.shape[0], out.shape[-1]), device=out.device, dtype=out.dtype) |
|
|
|
|
out[i] = torch.cat((dets, pad)) |
|
|
|
|
return (out, preds[1]) if self.model.task == "segment" else out |
|
|
|
|