mct-2.1.1
Laughing-q 7 months ago
parent 9560c868af
commit 9d70ad09d0
  1. 156
      ultralytics/nn/modules/head.py

@ -18,6 +18,162 @@ from .utils import bias_init_with_prob, linear_init
__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"
class DetectNEW(nn.Module):
"""YOLOv8 Detect head for detection models."""
dynamic = False # force grid reconstruction
export = False # export mode
end2end = False # end2end
max_det = 300 # max_det
shape = None
anchors = torch.empty(0) # init
strides = torch.empty(0) # init
def __init__(self, nc=80, ch=()):
"""Initializes the YOLOv8 detection layer with specified number of classes and channels."""
super().__init__()
self.nc = nc # number of classes
self.nl = len(ch) # number of detection layers
self.reg_max = 16 # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
self.no = nc + self.reg_max * 4 # number of outputs per anchor
self.stride = torch.zeros(self.nl) # strides computed during build
c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100)) # channels
self.cv2 = nn.ModuleList(
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
)
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
if self.end2end:
self.one2one_cv2 = copy.deepcopy(self.cv2)
self.one2one_cv3 = copy.deepcopy(self.cv3)
self.relu = nn.ReLU()
def forward(self, x):
"""Concatenates and returns predicted bounding boxes and class probabilities."""
if self.end2end:
return self.forward_end2end(x)
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training: # Training path
return x
y = self._inference(x)
return y if self.export else (y, x)
def forward_end2end(self, x):
"""
Performs forward pass of the v10Detect module.
Args:
x (tensor): Input tensor.
Returns:
(dict, tensor): If not in training mode, returns a dictionary containing the outputs of both one2many and one2one detections.
If in training mode, returns a dictionary containing the outputs of one2many and one2one detections separately.
"""
x_detach = [xi.detach() for xi in x]
one2one = [
torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
]
for i in range(self.nl):
x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
if self.training: # Training path
return {"one2many": x, "one2one": one2one}
y = self._inference(one2one)
y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
return y if self.export else (y, {"one2many": x, "one2one": one2one})
def _inference(self, x):
"""Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
# Inference path
shape = x[0].shape # BCHW
x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
# if self.dynamic or self.shape != shape:
self.anchors, self.strides = (
x.transpose(0, 1) for x in make_anchors(torch.Tensor([80, 40, 20]).cuda(), self.stride, 0.5)
)
self.shape = shape
self.strides /= 640
if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
box = x_cat[:, : self.reg_max * 4]
cls = x_cat[:, self.reg_max * 4 :]
else:
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
if self.export and self.format in {"tflite", "edgetpu"}:
# Precompute normalization factor to increase numerical stability
# See https://github.com/ultralytics/ultralytics/issues/7371
grid_h = shape[2]
grid_w = shape[3]
grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
norm = self.strides / (self.stride[0] * grid_size)
dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
else:
dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides)
y1 = self.relu(dbox[:, 0, :])
x1 = self.relu(dbox[:, 1, :])
y2 = self.relu(dbox[:, 2, :])
x2 = self.relu(dbox[:, 3, :])
y_bb = torch.stack((x1, y1, x2, y2), 1).transpose(1, 2)
return y_bb, cls.sigmoid().permute(0, 2, 1)
def bias_init(self):
"""Initialize Detect() biases, WARNING: requires stride availability."""
m = self # self.model[-1] # Detect() module
# cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
# ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # nominal class frequency
for a, b, s in zip(m.cv2, m.cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
if self.end2end:
for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride): # from
a[-1].bias.data[:] = 1.0 # box
b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2) # cls (.01 objects, 80 classes, 640 img)
def decode_bboxes(self, bboxes, anchors):
"""Decode bounding boxes."""
return dist2bbox(bboxes, anchors, xywh=False, dim=1)
@staticmethod
def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
"""
Post-processes the predictions obtained from a YOLOv10 model.
Args:
preds (torch.Tensor): The predictions obtained from the model. It should have a shape of (batch_size, num_boxes, 4 + num_classes).
max_det (int): The maximum number of detections to keep.
nc (int, optional): The number of classes. Defaults to 80.
Returns:
(torch.Tensor): The post-processed predictions with shape (batch_size, max_det, 6),
including bounding boxes, scores and cls.
"""
assert 4 + nc == preds.shape[-1]
boxes, scores = preds.split([4, nc], dim=-1)
max_scores = scores.amax(dim=-1)
max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), axis=-1)
index = index.unsqueeze(-1)
boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))
# NOTE: simplify but result slightly lower mAP
# scores, labels = scores.max(dim=-1)
# return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)
scores, index = torch.topk(scores.flatten(1), max_det, axis=-1)
labels = index % nc
index = index // nc
boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))
return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
class Detect(nn.Module):
def __init__(self, nc: int = 80, ch=()):
"""

Loading…
Cancel
Save