clean up modules

clean-exp-bk
Laughing-q 5 months ago
parent 8fc4a065ef
commit d0c6bb3d82
  1. ultralytics/nn/modules/__init__.py (18 changed lines)
  2. ultralytics/nn/modules/block.py (121 changed lines)
  3. ultralytics/nn/modules/head.py (185 changed lines)
  4. ultralytics/nn/tasks.py (25 changed lines)

--- a/ultralytics/nn/modules/__init__.py
+++ b/ultralytics/nn/modules/__init__.py
@@ -47,15 +47,7 @@ from .block import (
     ResNetLayer,
     Silence,
     C2f2,
-    C3f2,
-    C3F2,
-    C2k2,
     C3k2,
-    C3K2,
-    C3n2,
-    C3s2,
-    C3k3,
-    C3m1,
     SCDown,
     C2fPSA,
     C2PSA
@@ -75,7 +67,7 @@ from .conv import (
     RepConv,
     SpatialAttention,
 )
-from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect, Detect4
+from .head import OBB, Classify, Detect, Pose, RTDETRDecoder, Segment, WorldDetect
 from .transformer import (
     AIFI,
     MLP,
@@ -117,15 +109,7 @@ __all__ = (
     "C3",
     "C2f",
     "C2f2",
-    "C3f2",
-    "C3F2",
-    "C2k2",
     "C3k2",
-    "C3K2",
-    "C3k3",
-    "C3s2",
-    "C3n2",
-    "C3m1",
     "SCDown",
     "C2fPSA",
     "C2PSA",

--- a/ultralytics/nn/modules/block.py
+++ b/ultralytics/nn/modules/block.py
@@ -38,15 +38,7 @@ __all__ = (
     "CBLinear",
     "Silence",
     "C2f2",
-    "C3f2",
-    "C3F2",
-    "C2k2",
     "C3k2",
-    "C3K2",
-    "C3s2",
-    "C3n2",
-    "C3k3",
-    "C3m1",
     "SCDown",
     "C2fPSA",
     "C2PSA",
@@ -699,7 +691,7 @@ class CBFuse(nn.Module):
         return out
-# -------------------Experimental-------------------
-# TODO: clean this
 class C2f2(nn.Module):
     """Faster Implementation of CSP Bottleneck with 2 convolutions."""
@@ -747,44 +739,6 @@ class C3f(nn.Module):
         return self.cv3(torch.cat(y, 1))
-
-
-class C3f2(nn.Module):
-    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
-
-    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, nk=2):
-        """Initialize CSP bottleneck layer with two convolutions with arguments ch_in, ch_out, number, shortcut, groups,
-        expansion.
-        """
-        super().__init__()
-        c_ = int(c2 * e)  # hidden channels
-        self.cv1 = Conv(c1, c_, 1, 1)
-        self.cv2 = Conv(c1, c_, 1, 1)
-        self.cv3 = Conv((2 + n) * c_, c2, 1)  # optional act=FReLU(c2)
-        self.m = nn.ModuleList(C3k(c_, c_, nk, shortcut, g) for _ in range(n))
-
-    def forward(self, x):
-        """Forward pass through C2f layer."""
-        y = [self.cv2(x), self.cv1(x)]
-        y.extend(m(y[-1]) for m in self.m)
-        return self.cv3(torch.cat(y, 1))
-
-
-class C3F2(C3f2):
-    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
-
-    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        c_ = int(c2 * e)  # hidden channels
-        self.m = nn.ModuleList(C3f(c_, c_, 2, shortcut, g) for _ in range(n))
-
-
-class C2k2(C2f2):
-    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
-
-    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        self.m = nn.ModuleList(C2(self.c, self.c, 2, shortcut, g) for _ in range(n))
-
-
 class C3k2(C2f2):
     """Faster Implementation of CSP Bottleneck with 2 convolutions."""
@@ -795,53 +749,6 @@ class C3k2(C2f2):
         )
-
-
-class C3k3(C2f2):
-    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
-
-    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        self.m = nn.ModuleList(C3k(self.c, self.c, 3, shortcut, g) for _ in range(n))
-
-
-class C3K2(C2f2):
-    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
-
-    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        self.m = nn.ModuleList(C3K(self.c, self.c, 2, shortcut, g) for _ in range(n))
-
-
-class C3m1(C2f2):
-    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
-
-    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        self.m = nn.ModuleList(C3m(self.c, self.c, 1, shortcut, g) for _ in range(n))
-
-
-class C3n2(C2f2):
-    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
-
-    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        self.m = nn.ModuleList(C3n(self.c, self.c, shortcut, g) for _ in range(n))
-
-
-class C3s2(C2f2):
-    """Faster Implementation of CSP Bottleneck with 2 convolutions."""
-
-    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        self.m = nn.ModuleList(C3s(self.c, self.c, n, shortcut, g) for _ in range(n))
-
-
-class C3m(C3):
-    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        c_ = int(c2 * e)  # hidden channels
-        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(5, 5), e=1.0) for _ in range(n)))
 class C3k(C3):
     def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5, k=3):
         super().__init__(c1, c2, n, shortcut, g, e)
@@ -850,32 +757,6 @@ class C3k(C3):
         self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(k, k), e=1.0) for _ in range(n)))
-
-
-class C3K(C3):
-    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        c_ = int(c2 * e)  # hidden channels
-        # self.cv2 = Conv(c1, c_, 3, 1)
-        self.cv3 = Conv(2 * c_, c2, 3)  # optional act=FReLU(c2)
-        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(3, 3), e=1.0) for _ in range(n)))
-
-
-class C3n(C3k):
-    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
-        super().__init__(c1, c2, 2, shortcut, g, e)
-
-    def forward(self, x):
-        """Forward pass through the CSP bottleneck with 2 convolutions."""
-        return self.cv3(torch.cat((self.m[0](self.cv1(x)), self.m[1](self.cv2(x))), 1))
-
-
-class C3s(C3):
-    def __init__(self, c1, c2, n, shortcut=True, g=1, e=0.5):
-        super().__init__(c1, c2, n, shortcut, g, e)
-        c_ = int(c2 * e)  # hidden channels
-        self.cv2 = Conv(c1, c_, 3, 1)
-        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=(3, 1), e=1.0) for _ in range(n)))
 class SCDown(nn.Module):
     def __init__(self, c1, c2, k, s):
         """

--- a/ultralytics/nn/modules/head.py
+++ b/ultralytics/nn/modules/head.py
@@ -37,15 +37,6 @@ class Detect(nn.Module):
         self.cv2 = nn.ModuleList(
             nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
         )
-        # self.cv2 = nn.ModuleList(
-        #     nn.Sequential(
-        #         nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
-        #         nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
-        #         nn.Conv2d(c3, 4 * self.reg_max, 1),
-        #     )
-        #     for x in ch
-        # )
-        # self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
         self.cv3 = nn.ModuleList(
             nn.Sequential(
                 nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
@@ -104,182 +95,6 @@ class Detect(nn.Module):
         return dist2bbox(bboxes, anchors, xywh=True, dim=1)
-
-
-class Detect4(nn.Module):
-    """YOLOv8 Detect head for detection models."""
-
-    dynamic = False  # force grid reconstruction
-    export = False  # export mode
-    shape = None
-    anchors = torch.empty(0)  # init
-    strides = torch.empty(0)  # init
-
-    def __init__(self, nc=80, ch=()):
-        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
-        super().__init__()
-        self.nc = nc  # number of classes
-        self.nl = len(ch)  # number of detection layers
-        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
-        self.no = nc + self.reg_max * 4  # number of outputs per anchor
-        self.stride = torch.zeros(self.nl)  # strides computed during build
-        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
-        self.cv2 = nn.ModuleList(
-            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
-        )
-        # self.cv2 = nn.ModuleList(
-        #     nn.Sequential(
-        #         nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
-        #         nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
-        #         nn.Conv2d(c3, 4 * self.reg_max, 1),
-        #     )
-        #     for x in ch
-        # )
-        # self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
-        self.cv3 = nn.ModuleList(
-            nn.Sequential(
-                nn.Sequential(DWConv(x, x, 3), Conv(x, c3, 1)),
-                nn.Sequential(DWConv(c3, c3, 3), Conv(c3, c3, 1)),
-                nn.Conv2d(c3, self.nc, 1),
-            )
-            for x in ch
-        )
-        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
-        self.cv4 = nn.ModuleList(
-            nn.Sequential(
-                Conv(4 * self.reg_max, 16, 1),
-                nn.Conv2d(16, 1, 1),
-                # nn.Conv2d(4 * 5, 1, 1),
-                # nn.Conv2d(4 * 1, 1, 1),
-                nn.Sigmoid(),
-            )
-            for _ in ch
-        )
-
-    def forward(self, x):
-        """Concatenates and returns predicted bounding boxes and class probabilities."""
-        for i in range(self.nl):
-            box = self.cv2[i](x[i])
-            cls = self.cv3[i](x[i])
-            conf = self.cv4[i](box)
-            # N, C, H, W = box.size()
-            # prob = box.view(N, 4, self.reg_max, H, W).softmax(dim=2)
-            # prob_topk = prob.topk(4, dim=2)[0]
-            # prob_topk = torch.cat([prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2).view(N, -1, H, W)
-            # conf = self.cv4[i](prob_topk)
-            # N, C, H, W = box.size()
-            # prob = box.view(N, 4, self.reg_max, H, W).softmax(dim=2)
-            # prob_max = prob.amax(dim=2)
-            # prob_mean = prob.mean(dim=2)
-            # conf = self.cv4[i](torch.cat([prob_max, prob_mean], dim=1))
-            # # conf = self.cv4[i](prob_max)
-            x[i] = torch.cat((box, cls * conf), 1)
-        if self.training:  # Training path
-            return x
-
-        # Inference path
-        shape = x[0].shape  # BCHW
-        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
-        if self.dynamic or self.shape != shape:
-            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
-            self.shape = shape
-
-        if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:  # avoid TF FlexSplitV ops
-            box = x_cat[:, : self.reg_max * 4]
-            cls = x_cat[:, self.reg_max * 4 :]
-        else:
-            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-
-        if self.export and self.format in {"tflite", "edgetpu"}:
-            # Precompute normalization factor to increase numerical stability
-            # See https://github.com/ultralytics/ultralytics/issues/7371
-            grid_h = shape[2]
-            grid_w = shape[3]
-            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
-            norm = self.strides / (self.stride[0] * grid_size)
-            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
-        else:
-            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
-
-        y = torch.cat((dbox, cls.sigmoid()), 1)
-        return y if self.export else (y, x)
-
-    def bias_init(self):
-        """Initialize Detect() biases, WARNING: requires stride availability."""
-        m = self  # self.model[-1]  # Detect() module
-        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
-        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
-        for a, b, c, s in zip(m.cv2, m.cv3, m.cv4, m.stride):  # from
-            a[-1].bias.data[:] = 1.0  # box
-            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
-            c[-2].bias.data[:] = 2.0  # cls (.01 objects, 80 classes, 640 img)
-
-    def decode_bboxes(self, bboxes, anchors):
-        """Decode bounding boxes."""
-        return dist2bbox(bboxes, anchors, xywh=True, dim=1)
-
-
-class DetectNew(nn.Module):
-    """YOLOv8 Detect head for detection models."""
-
-    dynamic = False  # force grid reconstruction
-    export = False  # export mode
-    shape = None
-    anchors = torch.empty(0)  # init
-    strides = torch.empty(0)  # init
-
-    def __init__(self, nc=80, ch=()):  # detection layer
-        super().__init__()
-        self.nc = nc  # number of classes
-        self.nl = len(ch)  # number of detection layers
-        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
-        self.no = nc + self.reg_max * 4  # number of outputs per anchor
-        self.stride = torch.zeros(self.nl)  # strides computed during build
-        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
-        c23 = (c2 + c3) // 2 * 2
-        self.cv23 = nn.ModuleList(
-            nn.Sequential(Conv(x, c23, 3, g=2), Conv(c23, c23, 3, g=2), nn.Conv2d(c23, self.no, 1)) for x in ch
-        )
-        # self.cv23 = nn.ModuleList(
-        #     nn.Sequential(Conv(x, c23, 3, g=2),
-        #                   Conv(c23, c23, 3, g=2),
-        #                   ConvSplit(c23, c2_list=[4 * self.reg_max, self.nc], k=(1, 1))) for x in ch)
-        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()
-
-    def forward(self, x):
-        """Concatenates and returns predicted bounding boxes and class probabilities."""
-        shape = x[0].shape  # BCHW
-        for i in range(self.nl):
-            x[i] = self.cv23[i](x[i])
-        if self.training:
-            return x
-        elif self.dynamic or self.shape != shape:
-            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
-            self.shape = shape
-
-        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
-        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
-            box = x_cat[:, : self.reg_max * 4]
-            cls = x_cat[:, self.reg_max * 4 :]
-        else:
-            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
-        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
-        y = torch.cat((dbox, cls.sigmoid()), 1)
-        return y if self.export else (y, x)
-
-    def bias_init(self):
-        """Initialize Detect() biases, WARNING: requires stride availability."""
-        m = self  # self.model[-1]  # Detect() module
-        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
-        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
-        n = 4 * m.reg_max  # number of box channels
-        for a, s in zip(m.cv23, m.stride):  # from
-            a[-1].bias.data[:n] = 0.0  # box
-            a[-1].bias.data[n : n + m.nc] = math.log(
-                5 / m.nc / (640 / s) ** 2
-            )  # cls (.01 objects, 80 classes, 640 img)
-
-
 class Segment(Detect):
     """YOLOv8 Segment head for segmentation models."""

--- a/ultralytics/nn/tasks.py
+++ b/ultralytics/nn/tasks.py
@@ -25,15 +25,7 @@ from ultralytics.nn.modules import (
     BottleneckCSP,
     C2f,
     C2f2,
-    C3F2,
-    C3f2,
-    C2k2,
     C3k2,
-    C3s2,
-    C3n2,
-    C3K2,
-    C3m1,
-    C3k3,
     C2fAttn,
     C3Ghost,
     C3x,
@@ -45,7 +37,6 @@ from ultralytics.nn.modules import (
     Conv2,
     ConvTranspose,
     Detect,
-    Detect4,
     DWConv,
     DWConvTranspose2d,
     Focus,
@@ -245,7 +236,7 @@ class BaseModel(nn.Module):
         """
         self = super()._apply(fn)
         m = self.model[-1]  # Detect()
-        if isinstance(m, (Detect, Detect4)):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
             m.stride = fn(m.stride)
             m.anchors = fn(m.anchors)
             m.strides = fn(m.strides)
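
Reviewer note: the isinstance cleanups in this hunk and the next ride on a subtlety worth keeping in mind: Detect stores stride/anchors/strides as plain tensor attributes, not registered buffers, so the default nn.Module._apply (which backs .to(), .cuda(), .half()) never converts them; the override re-applies fn by hand. A simplified sketch (TinyHead and TinyModel are hypothetical stand-ins, not repo classes):

import torch
import torch.nn as nn

class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.stride = torch.tensor([8.0, 16.0, 32.0])  # plain attribute, not a buffer

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(TinyHead())

    def _apply(self, fn):
        self = super()._apply(fn)  # handles parameters and registered buffers
        m = self.model[-1]
        if isinstance(m, TinyHead):  # an isinstance check also covers subclasses
            m.stride = fn(m.stride)  # convert the plain tensor by hand
        return self

m = TinyModel().half()
print(m.model[-1].stride.dtype)  # -> torch.float16

This is also why dropping Detect4 lets the check shrink back to isinstance(m, Detect): every remaining head subclasses Detect.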
@@ -304,7 +295,7 @@ class DetectionModel(BaseModel):
         # Build strides
         m = self.model[-1]  # Detect()
-        if isinstance(m, (Detect, Detect4)):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
+        if isinstance(m, Detect):  # includes all Detect subclasses like Segment, Pose, OBB, WorldDetect
             s = 256  # 2x min stride
             m.inplace = self.inplace
             forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose, OBB)) else self.forward(x)
@@ -891,15 +882,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             C2,
             C2f,
             C2f2,
-            C3f2,
-            C3F2,
-            C2k2,
             C3k2,
-            C3n2,
-            C3s2,
-            C3K2,
-            C3k3,
-            C3m1,
             RepNCSPELAN4,
             ADown,
             SCDown,
@@ -923,7 +906,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
                 )  # num heads
             args = [c1, c2, *args[1:]]
-            if m in (BottleneckCSP, C1, C2, C2f, C2f2, C3f2, C3F2, C2k2, C3s2, C3n2, C3k2, C3K2, C3m1, C3k3, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fPSA, C2PSA):
+            if m in (BottleneckCSP, C1, C2, C2f, C2f2, C3k2, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3, C2fPSA, C2PSA):
                 args.insert(2, n)  # number of repeats
                 n = 1
             if m is C3k2 and max_channels == 512:  # for M/L/X sizes
@@ -942,7 +925,7 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        elif m in {Detect, Detect4, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
+        elif m in {Detect, WorldDetect, Segment, Pose, OBB, ImagePoolingAttn}:
            args.append([ch[x] for x in f])
            if m is Segment:
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
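
Reviewer note: the parse_model hunks rely on a repeat-count convention: for blocks in the CSP tuple, the YAML "number" field n is inserted as the module's third constructor argument and the outer repeat count is reset to 1, so a single instance builds its n inner bottlenecks itself. A small sketch of that convention (build_args and REPEATABLE are illustrative names, not the repo's API):

REPEATABLE = {"C2f", "C2f2", "C3k2", "C3"}  # abridged stand-in for the real tuple

def build_args(m_name, c1, c2, n, extra=()):
    """Mimic parse_model's arg shuffling for repeatable CSP blocks."""
    args = [c1, c2, *extra]
    if m_name in REPEATABLE:
        args.insert(2, n)  # the repeat count becomes the block's own argument
        n = 1              # so the module is instantiated exactly once
    return args, n

print(build_args("C3k2", 128, 256, 3))  # -> ([128, 256, 3], 1)

Trimming the tuple down to the surviving C2f2/C3k2 is what makes the matching import and __all__ deletions above safe.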
