|
|
|
@@ -18,6 +18,162 @@ from .utils import bias_init_with_prob, linear_init
__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"


class DetectNEW(nn.Module):
    """YOLOv8 Detect head for detection models (export-oriented variant with normalized box outputs)."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    format = None  # export format, set externally by the exporter (read in _inference)
    end2end = False  # end2end (one2one) inference
    max_det = 300  # maximum number of detections
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init
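
    # These class-level flags are toggled externally: the exporter is expected to set
    # `export` and `format`, while subclasses such as v10Detect enable `end2end`.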

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

        if self.end2end:
            self.one2one_cv2 = copy.deepcopy(self.cv2)
            self.one2one_cv3 = copy.deepcopy(self.cv3)

        self.relu = nn.ReLU()  # used in _inference to clamp decoded, normalized box coordinates at zero
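
        # Shape sketch (hypothetical numbers): with nc=80 and reg_max=16, each scale predicts
        # no = 80 + 16 * 4 = 144 channels per cell; at a 640px input with strides 8/16/32 the
        # flattened anchor count is 80*80 + 40*40 + 20*20 = 8400.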

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        if self.end2end:
            return self.forward_end2end(x)

        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return x
        y = self._inference(x)
        return y if self.export else (y, x)
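
    # Note: in export mode forward() returns only the decoded predictions; otherwise the raw
    # per-level maps come back alongside them, matching the upstream Detect head's contract.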

    def forward_end2end(self, x):
        """
        Performs forward pass of the v10Detect module.

        Args:
            x (List[torch.Tensor]): Input feature maps, one per detection layer.

        Returns:
            (dict | tuple): In training mode, a dict with the "one2many" and "one2one" branch outputs.
                In inference mode, the post-processed one2one predictions, returned alone in export
                mode or together with the raw branch dict otherwise.
        """
        x_detach = [xi.detach() for xi in x]  # one2one branch trains on detached features
        one2one = [
            torch.cat((self.one2one_cv2[i](x_detach[i]), self.one2one_cv3[i](x_detach[i])), 1) for i in range(self.nl)
        ]
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:  # Training path
            return {"one2many": x, "one2one": one2one}

        # NOTE: this class's modified _inference returns a (boxes, scores) tuple, so the permute
        # below assumes the upstream single-tensor output; this end2end path would need adapting.
        y = self._inference(one2one)
        y = self.postprocess(y.permute(0, 2, 1), self.max_det, self.nc)
        return y if self.export else (y, {"one2many": x, "one2one": one2one})

    def _inference(self, x):
        """Decode predicted bounding boxes and class probabilities based on multiple-level feature maps."""
        # Inference path
        shape = x[0].shape  # BCHW
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        # Anchors are rebuilt unconditionally here; upstream caches them behind
        # `if self.dynamic or self.shape != shape:`. The hardcoded 80/40/20 grid sizes
        # assume a 640x640 input with strides 8/16/32.
        self.anchors, self.strides = (
            x.transpose(0, 1) for x in make_anchors(torch.tensor([80.0, 40.0, 20.0], device=x_cat.device), self.stride, 0.5)
        )
        self.shape = shape
        self.strides /= 640  # normalize strides so decoded boxes land in [0, 1] for a 640px input

        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)

        if self.export and self.format in {"tflite", "edgetpu"}:
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box) * self.strides, self.anchors.unsqueeze(0) * self.strides)

        # decode_bboxes returns xyxy along dim 1; clamp at zero and reorder to [y1, x1, y2, x2]
        x1 = self.relu(dbox[:, 0, :])
        y1 = self.relu(dbox[:, 1, :])
        x2 = self.relu(dbox[:, 2, :])
        y2 = self.relu(dbox[:, 3, :])
        y_bb = torch.stack((y1, x1, y2, x2), 1).transpose(1, 2)  # (B, 4, N) -> (B, N, 4)

        return y_bb, cls.sigmoid().permute(0, 2, 1)  # boxes (B, N, 4), class scores (B, N, nc)
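
    # Example (assuming the 640px / strides 8-16-32 setup above, nc=80): _inference returns
    #   y_bb -> (B, 8400, 4) normalized [y1, x1, y2, x2], clamped at 0
    #   cls  -> (B, 8400, 80) sigmoid class scores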

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
        if self.end2end:
            for a, b, s in zip(m.one2one_cv2, m.one2one_cv3, m.stride):  # from
                a[-1].bias.data[:] = 1.0  # box
                b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
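
    # The cls bias prior above makes the initial sigmoid output correspond to roughly 5 objects
    # per 640px image at each stride s, spread over m.nc classes and the (640 / s) ** 2 grid cells.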

    def decode_bboxes(self, bboxes, anchors):
        """Decode DFL distance predictions into xyxy bounding boxes."""
        return dist2bbox(bboxes, anchors, xywh=False, dim=1)
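
    # For reference: with xywh=False, dist2bbox turns per-side distances (l, t, r, b) around an
    # anchor point (xa, ya) into corners (xa - l, ya - t, xa + r, ya + b).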

    @staticmethod
    def postprocess(preds: torch.Tensor, max_det: int, nc: int = 80):
        """
        Post-processes the predictions obtained from a YOLOv10 model.

        Args:
            preds (torch.Tensor): Raw predictions with shape (batch_size, num_boxes, 4 + nc).
            max_det (int): The maximum number of detections to keep.
            nc (int, optional): The number of classes. Defaults to 80.

        Returns:
            (torch.Tensor): Post-processed predictions with shape (batch_size, max_det, 6),
                holding the box (4), score (1) and class index (1) of each detection.
        """
        assert 4 + nc == preds.shape[-1]
        boxes, scores = preds.split([4, nc], dim=-1)
        max_scores = scores.amax(dim=-1)
        max_scores, index = torch.topk(max_scores, min(max_det, max_scores.shape[1]), dim=-1)
        index = index.unsqueeze(-1)
        boxes = torch.gather(boxes, dim=1, index=index.repeat(1, 1, boxes.shape[-1]))
        scores = torch.gather(scores, dim=1, index=index.repeat(1, 1, scores.shape[-1]))

        # NOTE: the simpler alternative below gives a slightly lower mAP
        # scores, labels = scores.max(dim=-1)
        # return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1)], dim=-1)

        scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
        labels = index % nc
        index = index // nc
        boxes = boxes.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, boxes.shape[-1]))

        return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)
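
    # Usage sketch (hypothetical shapes): preds of (1, 8400, 84) with nc=80 and max_det=300
    # yields a (1, 300, 6) tensor whose rows are [box(4), score, class].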


class Detect(nn.Module):
    def __init__(self, nc: int = 80, ch=()):
        """