From 0ad7e7a0f21bd7ad38cfa97db39de50e52df3617 Mon Sep 17 00:00:00 2001 From: Laughing-q <1185102784@qq.com> Date: Fri, 7 Jun 2024 17:21:57 +0800 Subject: [PATCH] light cls head --- ultralytics/nn/modules/head.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 1cf4cf4c7..28c9cbb58 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -9,7 +9,7 @@ from torch.nn.init import constant_, xavier_uniform_ from ultralytics.utils.tal import TORCH_1_10, dist2bbox, dist2rbox, make_anchors from .block import DFL, BNContrastiveHead, ContrastiveHead, Proto -from .conv import Conv +from .conv import Conv, DWConv from .transformer import MLP, DeformableTransformerDecoder, DeformableTransformerDecoderLayer from .utils import bias_init_with_prob, linear_init @@ -37,7 +37,15 @@ class Detect(nn.Module): self.cv2 = nn.ModuleList( nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch ) - self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) + # self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) + self.cv3 = nn.ModuleList( + nn.Sequential( + nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), + nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), + nn.Conv2d(c3, self.nc, 1), + ) + for x in ch + ) self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() def forward(self, x):