From 72184613a2e7c308f477d89a4ab9a78d46a6b2d5 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Thu, 1 Aug 2024 17:22:41 +0800 Subject: [PATCH] Update head.py --- ultralytics/nn/modules/head.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 0b75f4d77c..c06b0fa98b 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -91,13 +91,14 @@ class Detect(nn.Module): # Inference path shape = x[0].shape # BCHW x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) - # if self.dynamic or self.shape != shape: - self.anchors, self.strides = ( - x.transpose(0, 1) for x in make_anchors(torch.Tensor([80, 40, 20]).cuda(), self.stride, 0.5) - ) - self.shape = shape if self.export and self.format == "sony": + self.anchors, self.strides = ( + x.transpose(0, 1) for x in make_anchors(torch.Tensor([80, 40, 20]).cuda(), self.stride, 0.5) + ) self.strides /= 640 # NOTE: the relu could be removed in the future. + elif self.dynamic or self.shape != shape: + self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) + self.shape = shape if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops box = x_cat[:, : self.reg_max * 4] @@ -114,7 +115,7 @@ class Detect(nn.Module): norm = self.strides / (self.stride[0] * grid_size) dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) elif self.export and self.format == "sony": - dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides) + dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides, xywh=False) # NOTE: the relu could be removed in the future. y1 = self.relu(dbox[:, 0, :]) x1 = self.relu(dbox[:, 1, :]) @@ -140,9 +141,9 @@ class Detect(nn.Module): a[-1].bias.data[:] = 1.0 # box b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) - def decode_bboxes(self, bboxes, anchors): + def decode_bboxes(self, bboxes, anchors, xywh=True): """Decode bounding boxes.""" - return dist2bbox(bboxes, anchors, xywh=False, dim=1) + return dist2bbox(bboxes, anchors, xywh=xywh and (not self.end2end), dim=1) @staticmethod def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):