mct-2.1.1
Laughing-q 8 months ago
parent 9d70ad09d0
commit 45663520ae
  1. ultralytics/nn/modules/head.py (69 changed lines)
  2. ultralytics/nn/tasks.py (3 changed lines)

@@ -18,7 +18,7 @@ from .utils import bias_init_with_prob, linear_init
 __all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"


-class DetectNEW(nn.Module):
+class Detect(nn.Module):
     """YOLOv8 Detect head for detection models."""

     dynamic = False  # force grid reconstruction
@@ -174,73 +174,6 @@ class DetectNEW(nn.Module):
         return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)


-class Detect(nn.Module):
-    def __init__(self, nc: int = 80, ch=()):
-        """
-        Detection layer for YOLOv8.
-
-        Args:
-            nc (int): Number of classes.
-            ch (List[int]): List of channel values for detection layers.
-        """
-        super().__init__()
-        self.nc = nc  # number of classes
-        self.nl = len(ch)  # number of detection layers
-        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
-        self.no = nc + self.reg_max * 4  # number of outputs per anchor
-        self.stride = torch.Tensor([8, 16, 32])
-        self.feat_sizes = torch.Tensor([80, 40, 20])
-        self.img_size = 640  # img size
-        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
-        self.cv2 = nn.ModuleList(
-            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
-        )
-        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
-        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
-
-        anchors, strides = (x.transpose(0, 1) for x in make_anchors(self.feat_sizes, self.stride, 0.5))
-        strides = strides / self.img_size
-        anchors = anchors * strides
-
-        self.relu1 = nn.ReLU()
-        self.relu2 = nn.ReLU()
-        self.relu3 = nn.ReLU()
-        self.relu4 = nn.ReLU()
-        self.anchors = anchors
-        self.strides = strides
-
-    def forward(self, x):
-        shape = x[0].shape  # BCHW
-        for i in range(self.nl):
-            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
-        box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
-        y_cls = cls.sigmoid().transpose(1, 2)
-
-        dfl = self.dfl(box)
-        dfl = dfl * self.strides
-
-        # box decoding
-        lt, rb = dfl.chunk(2, 1)
-        y1 = self.relu1(self.anchors.unsqueeze(0)[:, 0, :] - lt[:, 0, :])
-        x1 = self.relu2(self.anchors.unsqueeze(0)[:, 1, :] - lt[:, 1, :])
-        y2 = self.relu3(self.anchors.unsqueeze(0)[:, 0, :] + rb[:, 0, :])
-        x2 = self.relu4(self.anchors.unsqueeze(0)[:, 1, :] + rb[:, 1, :])
-        y_bb = torch.stack((x1, y1, x2, y2), 1).transpose(1, 2)
-        return y_bb, y_cls
-
-    def bias_init(self):
-        """Initialize Detect() biases, WARNING: requires stride availability."""
-        m = self  # self.model[-1]  # Detect() module
-        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
-            a[-1].bias.data[:] = 1.0  # box
-            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
-
-    def decode_bboxes(self, bboxes, anchors):
-        """Decode bounding boxes."""
-        return bboxes
-
-
 class Segment(Detect):
     """YOLOv8 Segment head for segmentation models."""

@@ -323,8 +323,7 @@ class DetectionModel(BaseModel):
                     return self.forward(x)["one2many"]
                 return self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)

-            # m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
-            self.stride = m.stride
+            m.stride = torch.tensor([s / x.shape[-2] for x in _forward(torch.zeros(1, ch, s, s))])  # forward
             m.bias_init()  # only run once
         else:
             self.stride = torch.Tensor([32])  # default stride for i.e. RTDETR
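For reference (not part of the commit), the re-enabled m.stride expression derives the per-level strides from a dummy forward pass at s = 256 instead of relying on the constants the deleted head hard-coded. A small sketch with fabricated feature-map shapes shows how the same expression recovers [8, 16, 32]:

import torch

s = 256  # dummy input size used by DetectionModel to probe the head
# Fabricated P3/P4/P5 outputs: 144 channels = nc (80) + 4 * reg_max (64) for the default head.
fake_feats = [
    torch.zeros(1, 144, 32, 32),  # 256 / 32 -> stride 8
    torch.zeros(1, 144, 16, 16),  # 256 / 16 -> stride 16
    torch.zeros(1, 144, 8, 8),    # 256 / 8  -> stride 32
]
stride = torch.tensor([s / x.shape[-2] for x in fake_feats])
print(stride)  # tensor([ 8., 16., 32.])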
