|
|
|
@ -41,14 +41,7 @@ class Detect(nn.Module): |
|
|
|
|
self.cv2 = nn.ModuleList( |
|
|
|
|
nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch |
|
|
|
|
) |
|
|
|
|
self.cv3 = nn.ModuleList( |
|
|
|
|
nn.Sequential( |
|
|
|
|
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), |
|
|
|
|
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), |
|
|
|
|
nn.Conv2d(c3, self.nc, 1), |
|
|
|
|
) |
|
|
|
|
for x in ch |
|
|
|
|
) |
|
|
|
|
self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) |
|
|
|
|
self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() |
|
|
|
|
|
|
|
|
|
if self.end2end: |
|
|
|
@ -595,3 +588,35 @@ class v10Detect(Detect): |
|
|
|
|
for x in ch |
|
|
|
|
) |
|
|
|
|
self.one2one_cv3 = copy.deepcopy(self.cv3) |
|
|
|
|
|
|
|
|
|
class v11Detect(Detect): |
|
|
|
|
""" |
|
|
|
|
v10 Detection head from https://arxiv.org/pdf/2405.14458. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
nc (int): Number of classes. |
|
|
|
|
ch (tuple): Tuple of channel sizes. |
|
|
|
|
|
|
|
|
|
Attributes: |
|
|
|
|
max_det (int): Maximum number of detections. |
|
|
|
|
|
|
|
|
|
Methods: |
|
|
|
|
__init__(self, nc=80, ch=()): Initializes the v10Detect object. |
|
|
|
|
forward(self, x): Performs forward pass of the v10Detect module. |
|
|
|
|
bias_init(self): Initializes biases of the Detect module. |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, nc=80, ch=()): |
|
|
|
|
"""Initializes the v10Detect object with the specified number of classes and input channels.""" |
|
|
|
|
super().__init__(nc, ch) |
|
|
|
|
c3 = max(ch[0], min(self.nc, 100)) # channels |
|
|
|
|
# Light cls head |
|
|
|
|
self.cv3 = nn.ModuleList( |
|
|
|
|
nn.Sequential( |
|
|
|
|
nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), |
|
|
|
|
nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), |
|
|
|
|
nn.Conv2d(c3, self.nc, 1), |
|
|
|
|
) |
|
|
|
|
for x in ch |
|
|
|
|
) |