mct-2.1.1
Laughing-q 3 months ago
commit a00e71a17a
  1. 20
      ultralytics/engine/exporter.py
  2. 8
      ultralytics/nn/modules/head.py
  3. 2
      ultralytics/utils/tal.py

@ -268,6 +268,20 @@ class Exporter:
elif isinstance(m, C2f) and not is_tf_format:
# EdgeTPU does not support FlexSplitV while split provides cleaner ONNX graph
m.forward = m.forward_split
if isinstance(m, Detect) and mct:
from ultralytics.utils.tal import make_anchors
anchors, strides = (
x.transpose(0, 1)
for x in make_anchors(
torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1),
m.stride,
0.5,
)
)
m.anchors = anchors
m.strides = strides
if isinstance(m, C2f) and mct:
m.forward = m.forward_fx
@ -1295,9 +1309,9 @@ class Exporter:
model = ct.models.MLModel(pipeline.spec, weights_dir=weights_dir)
model.input_description["image"] = "Input image"
model.input_description["iouThreshold"] = f"(optional) IoU threshold override (default: {nms.iouThreshold})"
model.input_description["confidenceThreshold"] = (
f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
)
model.input_description[
"confidenceThreshold"
] = f"(optional) Confidence threshold override (default: {nms.confidenceThreshold})"
model.output_description["confidence"] = 'Boxes × Class confidence (see user-defined metadata "classes")'
model.output_description["coordinates"] = "Boxes × [x, y, width, height] (relative to image size)"
LOGGER.info(f"{prefix} pipeline success")

@ -89,13 +89,7 @@ class Detect(nn.Module):
# Inference path
shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.export and self.format == "mct":
self.anchors, self.strides = (
x.transpose(0, 1)
for x in make_anchors(getattr(self, "feats_size", torch.Tensor([80, 40, 20])), self.stride, 0.5)
)
self.anchors, self.strides = self.anchors.to(x[0].device), self.strides.to(x[0].device)
elif self.dynamic or self.shape != shape:
if self.format != "mct" and (self.dynamic or self.shape != shape):
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape

@ -306,7 +306,7 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
assert feats is not None
dtype, device = feats[0].dtype, feats[0].device
for i, stride in enumerate(strides):
h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i]), int(feats[i]))
h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)

Loading…
Cancel
Save