|
|
|
@ -89,13 +89,7 @@ class Detect(nn.Module): |
|
|
|
|
# Inference path |
|
|
|
|
shape = x[0].shape # BCHW |
|
|
|
|
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) |
|
|
|
|
if self.export and self.format == "mct": |
|
|
|
|
self.anchors, self.strides = ( |
|
|
|
|
x.transpose(0, 1) |
|
|
|
|
for x in make_anchors(getattr(self, "feats_size", torch.Tensor([80, 40, 20])), self.stride, 0.5) |
|
|
|
|
) |
|
|
|
|
self.anchors, self.strides = self.anchors.to(x[0].device), self.strides.to(x[0].device) |
|
|
|
|
elif self.dynamic or self.shape != shape: |
|
|
|
|
if self.format != "mct" and (self.dynamic or self.shape != shape): |
|
|
|
|
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) |
|
|
|
|
self.shape = shape |
|
|
|
|
|
|
|
|
|