From f619f55a80581dbb8a36dc87a66217d32b997ad3 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Thu, 12 Sep 2024 16:41:55 +0800 Subject: [PATCH 1/3] Update head.py --- ultralytics/nn/modules/head.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index be0a8698fb..52dffddc18 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -89,13 +89,7 @@ class Detect(nn.Module): # Inference path shape = x[0].shape # BCHW x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) - if self.export and self.format == "mct": - self.anchors, self.strides = ( - x.transpose(0, 1) - for x in make_anchors(getattr(self, "feats_size", torch.Tensor([80, 40, 20])), self.stride, 0.5) - ) - self.anchors, self.strides = self.anchors.to(x[0].device), self.strides.to(x[0].device) - elif self.dynamic or self.shape != shape: + if self.format != "mct" and (self.dynamic or self.shape != shape): self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) self.shape = shape From 0d45fd8ef4c5e927e9b4b503941a8de405990d92 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Thu, 12 Sep 2024 16:59:03 +0800 Subject: [PATCH 2/3] Update exporter.py --- ultralytics/engine/exporter.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/ultralytics/engine/exporter.py b/ultralytics/engine/exporter.py index 6e9036659e..f35043aee3 100644 --- a/ultralytics/engine/exporter.py +++ b/ultralytics/engine/exporter.py @@ -268,6 +268,20 @@ class Exporter: elif isinstance(m, C2f) and not is_tf_format: # EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph m.forward = m.forward_split + if isinstance(m, Detect) and mct: + from ultralytics.utils.tal import make_anchors + + anchors, strides = ( + x.transpose(0, 1) + for x in make_anchors( + torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), + m.stride, + 0.5, + ) + ) + m.anchors = anchors + m.strides = strides + if isinstance(m, C2f) and mct: m.forward = m.forward_fx @@ -1295,9 +1309,9 @@ class Exporter: model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir) model.input_description["image"] = "Input image" model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})" - model.input_description["confidenceThreshold"] = ( - f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})" - ) + model.input_description[ + "confidenceThreshold" + ] = f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})" model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")' model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)" LOGGER.info(f"{prefix} pipeline success") From 4ded1eca39780904f0c542118e08728b9b8ea515 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Thu, 12 Sep 2024 16:59:09 +0800 Subject: [PATCH 3/3] Update tal.py --- ultralytics/utils/tal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py index f41c32bfd3..9fb5020923 100644 --- a/ultralytics/utils/tal.py +++ b/ultralytics/utils/tal.py @@ -306,7 +306,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5): assert feats is not None dtype, device = feats[0].dtype, feats[0].device for i, stride in enumerate(strides): - h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i]), int(feats[i])) + h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1])) sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)