diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 33c724b518..1a8e8b9676 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -59,16 +59,17 @@ class Detect(nn.Module): cls = x_cat[:, self.reg_max * 4 :] else: box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) - dbox = self.decode_bboxes(box) if self.export and self.format in ("tflite", "edgetpu"): # Precompute normalization factor to increase numerical stability # See https://github.com/ultralytics/ultralytics/issues/7371 - img_h = shape[2] - img_w = shape[3] - img_size = torch.tensor([img_w, img_h, img_w, img_h], device=box.device).reshape(1, 4, 1) - norm = self.strides / (self.stride[0] * img_size) - dbox = dist2bbox(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1) + grid_h = shape[2] + grid_w = shape[3] + grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * grid_size) + dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) + else: + dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides y = torch.cat((dbox, cls.sigmoid()), 1) return y if self.export else (y, x) @@ -82,9 +83,9 @@ class Detect(nn.Module): a[-1].bias.data[:] = 1.0 # box b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) - def decode_bboxes(self, bboxes): + def decode_bboxes(self, bboxes, anchors): """Decode bounding boxes.""" - return dist2bbox(self.dfl(bboxes), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides + return dist2bbox(bboxes, anchors, xywh=True, dim=1) class Segment(Detect): @@ -139,9 +140,9 @@ class OBB(Detect): return x, angle return torch.cat([x, angle], 1) if self.export else (torch.cat([x[0], angle], 1), (x[1], angle)) - def decode_bboxes(self, bboxes): + def decode_bboxes(self, bboxes, anchors): """Decode rotated bounding boxes.""" - return dist2rbox(self.dfl(bboxes), self.angle, self.anchors.unsqueeze(0), dim=1) * self.strides + return dist2rbox(bboxes, self.angle, anchors, dim=1) class Pose(Detect):