|
|
|
@ -28,6 +28,7 @@ class Detect(nn.Module): |
|
|
|
|
shape = None |
|
|
|
|
anchors = torch.empty(0) # init |
|
|
|
|
strides = torch.empty(0) # init |
|
|
|
|
legacy = False # backward compatibility for v3/v5/v8/v9 models |
|
|
|
|
|
|
|
|
|
def __init__(self, nc=80, ch=()): |
|
|
|
|
"""Initializes the YOLO detection layer with specified number of classes and channels.""" |
|
|
|
@ -41,7 +42,10 @@ class Detect(nn.Module): |
|
|
|
|
self.cv2 = nn.ModuleList( |
|
|
|
|
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch |
|
|
|
|
) |
|
|
|
|
self.cv3 = nn.ModuleList( |
|
|
|
|
self.cv3 = ( |
|
|
|
|
nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) |
|
|
|
|
if self.legacy |
|
|
|
|
else nn.ModuleList( |
|
|
|
|
nn.Sequential( |
|
|
|
|
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), |
|
|
|
|
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), |
|
|
|
@ -49,6 +53,7 @@ class Detect(nn.Module): |
|
|
|
|
) |
|
|
|
|
for x in ch |
|
|
|
|
) |
|
|
|
|
) |
|
|
|
|
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() |
|
|
|
|
|
|
|
|
|
if self.end2end: |
|
|
|
|