Update head.py

mct-2.1.1
Laughing-q 3 months ago
parent 785dd462c5
commit f619f55a80
  1. 8
      ultralytics/nn/modules/head.py

@ -89,13 +89,7 @@ class Detect(nn.Module):
# Inference path
shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
if self.export and self.format == "mct":
self.anchors, self.strides = (
x.transpose(0, 1)
for x in make_anchors(getattr(self, "feats_size", torch.Tensor([80, 40, 20])), self.stride, 0.5)
)
self.anchors, self.strides = self.anchors.to(x[0].device), self.strides.to(x[0].device)
elif self.dynamic or self.shape != shape:
if self.format != "mct" and (self.dynamic or self.shape != shape):
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.shape = shape

Loading…
Cancel
Save