@@ -18,7 +18,7 @@ from .utils import bias_init_with_prob, linear_init

__all__ = "Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder", "v10Detect"


class DetectNEW(nn.Module):
class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # force grid reconstruction


@@ -174,73 +174,6 @@ class DetectNEW(nn.Module):

        return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1)


class Detect(nn.Module):
    def __init__(self, nc: int = 80, ch=()):
        """
        Detection layer for YOLOv8.

        Args:
            nc (int): Number of classes.
            ch (List[int]): List of channel values for detection layers.
        """
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.Tensor([8, 16, 32])  # fixed strides for the three detection layers
        self.feat_sizes = torch.Tensor([80, 40, 20])  # feature-map sizes for a 640x640 input
        self.img_size = 640  # fixed input image size
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
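        # Worked example (illustrative, assuming YOLOv8n-like ch=(64, 128, 256)):
        #   c2 = max(16, 64 // 4, 16 * 4) = 64   -> width of the box branch
        #   c3 = max(64, min(80, 100)) = 80      -> width of the cls branch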
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
        # Precompute anchors and strides for the fixed input size, so forward() needs
        # no dynamic grid construction.
        anchors, strides = (x.transpose(0, 1) for x in make_anchors(self.feat_sizes, self.stride, 0.5))
        strides = strides / self.img_size  # normalize strides to [0, 1] image coordinates
        anchors = anchors * strides  # anchor centers in normalized image coordinates
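        # Worked example for the fixed 640x640 setup above: make_anchors over the
        # 80x80, 40x40 and 20x20 grids yields 80**2 + 40**2 + 20**2 = 8400 anchor
        # points, so after the transpose `anchors` is (2, 8400) and `strides` is
        # (1, 8400); dividing by 640 maps the strides {8, 16, 32} to
        # {0.0125, 0.025, 0.05}.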
        # Separate ReLU instances for the four clamps used in the box decoding below.
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.relu4 = nn.ReLU()

        self.anchors = anchors
        self.strides = strides

    def forward(self, x):
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        box, cls = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2).split((self.reg_max * 4, self.nc), 1)
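        # Shape walk-through (assuming nc=80, reg_max=16, 640x640 input): each x[i]
        # above is (B, 4*16 + 80, H, W) = (B, 144, H, W); flattening and concatenating
        # the three levels gives (B, 144, 8400), which splits into box (B, 64, 8400)
        # and cls (B, 80, 8400).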

        y_cls = cls.sigmoid().transpose(1, 2)  # (B, anchors, nc) class scores

        dfl = self.dfl(box)  # (B, 4, anchors) ltrb distances in grid units
        dfl = dfl * self.strides  # scale distances to normalized image coordinates

        # box decoding: anchor center -/+ (left, top) and (right, bottom) distances,
        # with each corner clamped to be non-negative by its own ReLU
        lt, rb = dfl.chunk(2, 1)
        y1 = self.relu1(self.anchors.unsqueeze(0)[:, 0, :] - lt[:, 0, :])
        x1 = self.relu2(self.anchors.unsqueeze(0)[:, 1, :] - lt[:, 1, :])
        y2 = self.relu3(self.anchors.unsqueeze(0)[:, 0, :] + rb[:, 0, :])
        x2 = self.relu4(self.anchors.unsqueeze(0)[:, 1, :] + rb[:, 1, :])
        # NOTE: row 0 of `anchors` holds the x coordinates in make_anchors' (x, y)
        # convention, so the y*/x* names above are swapped relative to it; the
        # stacking order below is kept exactly as written.
        y_bb = torch.stack((x1, y1, x2, y2), 1).transpose(1, 2)  # (B, anchors, 4)
        return y_bb, y_cls

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
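        # Worked example (nc=80, s=8): math.log(5 / 80 / (640 / 8) ** 2)
        # = log(5 / 80 / 6400) ≈ -11.5, i.e. class logits start strongly negative
        # so the freshly initialized head predicts very few objects.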

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        return bboxes  # identity: decoding already happened in forward()
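

# Minimal smoke test for the modified head, as a sketch only: it assumes the stock
# ultralytics Conv, DFL and make_anchors are importable from this package and uses
# illustrative channel widths (64, 128, 256); it is not part of the upstream API.
#
#   head = Detect(nc=80, ch=(64, 128, 256))
#   feats = [torch.randn(1, c, s, s) for c, s in zip((64, 128, 256), (80, 40, 20))]
#   y_bb, y_cls = head(feats)
#   assert y_bb.shape == (1, 8400, 4) and y_cls.shape == (1, 8400, 80)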


class Segment(Detect):
    """YOLOv8 Segment head for segmentation models."""