From 9d70ad09d02be3b17b6798c50eef05e6702a6eb4 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Tue, 23 Jul 2024 17:36:50 +0800 Subject: [PATCH] update --- ultralytics/nn/modules/head.py | 156 +++++++++++++++++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 3062ac07b5..52e5a51ae1 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -18,6 +18,162 @@ from .utils import bias_init_with_prob, linear_init __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect" +class DetectNEW(nn.Module): + """YOLOv8 Detect head for detection models.""" + + dynamic = False # force grid reconstruction + export = False # export mode + end2end = False # end2end + max_det = 300 # max_det + shape = None + anchors = torch.empty(0) # init + strides = torch.empty(0) # init + + def __init__(self, nc=80, ch=()): + """Initializes the YOLOv8 detection layer with specified number of classes and channels.""" + super().__init__() + self.nc = nc # number of classes + self.nl = len(ch) # number of detection layers + self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x) + self.no = nc + self.reg_max * 4 # number of outputs per anchor + self.stride = torch.zeros(self.nl) # strides computed during build + c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels + self.cv2 = nn.ModuleList( + nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch + ) + self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) + self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() + + if self.end2end: + self.one2one_cv2 = copy.deepcopy(self.cv2) + self.one2one_cv3 = copy.deepcopy(self.cv3) + + self.relu = nn.ReLU() + + def forward(self, x): + """Concatenates and returns predicted bounding boxes and class probabilities.""" + if self.end2end: + return self.forward_end2end(x) + + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + if self.training: # Training path + return x + y = self._inference(x) + return y if self.export else (y, x) + + def forward_end2end(self, x): + """ + Performs forward pass of the v10Detect module. + + Args: + x (tensor): Input tensor. + + Returns: + (dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections. + If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately. + """ + x_detach = [xi.detach() for xi in x] + one2one = [ + torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl) + ] + for i in range(self.nl): + x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1) + if self.training: # Training path + return {"one2many": x, "one2one": one2one} + + y = self._inference(one2one) + y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc) + return y if self.export else (y, {"one2many": x, "one2one": one2one}) + + def _inference(self, x): + """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps.""" + # Inference path + shape = x[0].shape # BCHW + x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2) + # if self.dynamic or self.shape != shape: + self.anchors, self.strides = ( + x.transpose(0, 1) for x in make_anchors(torch.Tensor([80, 40, 20]).cuda(), self.stride, 0.5) + ) + self.shape = shape + self.strides /= 640 + + if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops + box = x_cat[:, : self.reg_max * 4] + cls = x_cat[:, self.reg_max * 4 :] + else: + box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) + + if self.export and self.format in {"tflite", "edgetpu"}: + # Precompute normalization factor to increase numerical stability + # See https://github.com/ultralytics/ultralytics/issues/7371 + grid_h = shape[2] + grid_w = shape[3] + grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) + norm = self.strides / (self.stride[0] * grid_size) + dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2]) + else: + dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides) + + y1 = self.relu(dbox[:, 0, :]) + x1 = self.relu(dbox[:, 1, :]) + y2 = self.relu(dbox[:, 2, :]) + x2 = self.relu(dbox[:, 3, :]) + y_bb = torch.stack((x1, y1, x2, y2), 1).transpose(1, 2) + + return y_bb, cls.sigmoid().permute(0, 2, 1) + + def bias_init(self): + """Initialize Detect() biases, WARNING: requires stride availability.""" + m = self # self.model[-1] # Detect() module + # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1 + # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency + for a, b, s in zip(m.cv2, m.cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + if self.end2end: + for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from + a[-1].bias.data[:] = 1.0 # box + b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img) + + def decode_bboxes(self, bboxes, anchors): + """Decode bounding boxes.""" + return dist2bbox(bboxes, anchors, xywh=False, dim=1) + + @staticmethod + def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80): + """ + Post-processes the predictions obtained from a YOLOv10 model. + + Args: + preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes). + max_det (int): The maximum number of detections to keep. + nc (int, optional): The number of classes. Defaults to 80. + + Returns: + (torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6), + including bounding boxes, scores and cls. + """ + assert 4 + nc == preds.shape[-1] + boxes, scores = preds.split([4, nc], dim=-1) + max_scores = scores.amax(dim=-1) + max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1) + index = index.unsqueeze(-1) + boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1])) + scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1])) + + # NOTE: simplify but result slightly lower mAP + # scores, labels = scores.max(dim=-1) + # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1) + + scores, index = torch.topk(scores.flatten(1), max_det, axis=-1) + labels = index % nc + index = index // nc + boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1])) + + return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1) + + class Detect(nn.Module): def __init__(self, nc: int = 80, ch=()): """