diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 60911e779f..4bc1fa25e7 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -28,6 +28,7 @@ class Detect(nn.Module): shape = None anchors = torch.empty(0) # init strides = torch.empty(0) # init + legacy = False # backward compatibility for v3/v5/v8/v9 models def __init__(self, nc=80, ch=()): """Initializes the YOLO detection layer with specified number of classes and channels.""" @@ -41,13 +42,17 @@ class Detect(nn.Module): self.cv2 = nn.ModuleList( nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch ) - self.cv3 = nn.ModuleList( - nn.Sequential( - nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), - nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), - nn.Conv2d(c3, self.nc, 1), + self.cv3 = ( + nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch) + if self.legacy + else nn.ModuleList( + nn.Sequential( + nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)), + nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)), + nn.Conv2d(c3, self.nc, 1), + ) + for x in ch ) - for x in ch ) self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity() diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 12de1cfbf6..1e69a8f25c 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -936,6 +936,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) import ast # Args + legacy = True # backward compatibility for v3/v5/v8/v9 models max_channels = float("inf") nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales")) depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape")) @@ -1027,8 +1028,10 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) }: args.insert(2, n) # number of repeats n = 1 - if m is C3k2 and scale in "mlx": # for M/L/X sizes - args[3] = True + if m is 
C3k2: + legacy = False + if scale in "mlx": # for M/L/X sizes + args[3] = True elif m is AIFI: args = [ch[f], *args] elif m in {HGStem, HGBlock}: @@ -1047,6 +1050,8 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3) args.append([ch[x] for x in f]) if m is Segment: args[2] = make_divisible(min(args[2], max_channels) * width, 8) + if m in {Detect, Segment, Pose, OBB}: + m.legacy = legacy elif m is RTDETRDecoder: # special case, channels arg must be passed in index 1 args.insert(1, [ch[x] for x in f]) elif m is CBLinear: